feat(custom_logger.py): expose new async_dataset_hook for modifying… (#6331)

* feat(custom_logger.py): expose new `async_dataset_hook` for modifying/rejecting argilla items before logging

Allows user more control on what gets logged to argilla for annotations

* feat(google_ai_studio_endpoints.py): add new `/azure/*` pass through route

enables pass-through for azure provider

* feat(utils.py): support checking ollama `/api/show` endpoint for retrieving ollama model info

Fixes https://github.com/BerriAI/litellm/issues/6322

* fix(user_api_key_auth.py): add `/key/delete` to an allowed_ui_routes

Fixes https://github.com/BerriAI/litellm/issues/6236

* fix(user_api_key_auth.py): remove type ignore

* fix(user_api_key_auth.py): route ui vs. api token checks differently

Fixes https://github.com/BerriAI/litellm/issues/6238

* feat(internal_user_endpoints.py): support setting models as a default internal user param

Closes https://github.com/BerriAI/litellm/issues/6239

* fix(user_api_key_auth.py): fix exception string

* fix(user_api_key_auth.py): fix error string

* fix: fix test
This commit is contained in:
Krish Dholakia 2024-10-20 09:00:04 -07:00 committed by GitHub
parent 7cc12bd5c6
commit 905ebeb924
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 422 additions and 153 deletions

View file

@ -207,6 +207,7 @@ litellm_settings:
user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on

View file

@ -21,53 +21,22 @@ from pydantic import BaseModel # type: ignore
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
get_async_httpx_client, get_async_httpx_client,
httpxSpecialProvider, httpxSpecialProvider,
) )
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.types.integrations.argilla import (
SUPPORTED_PAYLOAD_FIELDS,
ArgillaCredentialsObject,
ArgillaItem,
ArgillaPayload,
)
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None
# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None
class ArgillaItem(TypedDict):
fields: Dict[str, Any]
class ArgillaPayload(TypedDict):
items: List[ArgillaItem]
class ArgillaCredentialsObject(TypedDict):
ARGILLA_API_KEY: str
ARGILLA_DATASET_NAME: str
ARGILLA_BASE_URL: str
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
def is_serializable(value): def is_serializable(value):
non_serializable_types = ( non_serializable_types = (
types.CoroutineType, types.CoroutineType,
@ -215,7 +184,7 @@ class ArgillaLogger(CustomBatchLogger):
def _prepare_log_data( def _prepare_log_data(
self, kwargs, response_obj, start_time, end_time self, kwargs, response_obj, start_time, end_time
) -> ArgillaItem: ) -> Optional[ArgillaItem]:
try: try:
# Ensure everything in the payload is converted to str # Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get( payload: Optional[StandardLoggingPayload] = kwargs.get(
@ -235,6 +204,7 @@ class ArgillaLogger(CustomBatchLogger):
argilla_item["fields"][k] = argilla_response argilla_item["fields"][k] = argilla_response
else: else:
argilla_item["fields"][k] = payload.get(v, None) argilla_item["fields"][k] = payload.get(v, None)
return argilla_item return argilla_item
except Exception: except Exception:
raise raise
@ -294,6 +264,9 @@ class ArgillaLogger(CustomBatchLogger):
response_obj, response_obj,
) )
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
if data is None:
return
self.log_queue.append(data) self.log_queue.append(data)
verbose_logger.debug( verbose_logger.debug(
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
@ -321,7 +294,25 @@ class ArgillaLogger(CustomBatchLogger):
kwargs, kwargs,
response_obj, response_obj,
) )
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger):
try:
if data is None:
break
data = await callback.async_dataset_hook(data, payload)
except NotImplementedError:
pass
if data is None:
return
self.log_queue.append(data) self.log_queue.append(data)
verbose_logger.debug( verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s", "Langsmith logging: queue length %s, batch size %s",

View file

@ -10,6 +10,7 @@ from pydantic import BaseModel
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import ChatCompletionRequest from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import ( from litellm.types.utils import (
@ -17,6 +18,7 @@ from litellm.types.utils import (
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
ModelResponse, ModelResponse,
StandardLoggingPayload,
) )
@ -108,6 +110,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
""" """
pass pass
### DATASET HOOKS #### - currently only used for Argilla
async def async_dataset_hook(
self,
logged_item: ArgillaItem,
standard_logging_payload: Optional[StandardLoggingPayload],
) -> Optional[ArgillaItem]:
"""
- Decide if the result should be logged to Argilla.
- Modify the result before logging to Argilla.
- Return None if the result should not be logged to Argilla.
"""
raise NotImplementedError("async_dataset_hook not implemented")
#### CALL HOOKS - proxy only #### #### CALL HOOKS - proxy only ####
""" """
Control the modify incoming / outgoung data before calling the model Control the modify incoming / outgoung data before calling the model

View file

@ -14,7 +14,8 @@ import requests # type: ignore
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.utils import ProviderField, StreamingChoices from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices
from .prompt_templates.factory import custom_prompt, prompt_factory from .prompt_templates.factory import custom_prompt, prompt_factory
@ -163,6 +164,56 @@ class OllamaConfig:
"response_format", "response_format",
] ]
def _supports_function_calling(self, ollama_model_info: dict) -> bool:
"""
Check if the 'template' field in the ollama_model_info contains a 'tools' or 'function' key.
"""
_template: str = str(ollama_model_info.get("template", "") or "")
return "tools" in _template.lower()
def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]:
_model_info: dict = ollama_model_info.get("model_info", {})
for k, v in _model_info.items():
if "context_length" in k:
return v
return None
def get_model_info(self, model: str) -> ModelInfo:
"""
curl http://localhost:11434/api/show -d '{
"name": "mistral"
}'
"""
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
try:
response = litellm.module_level_client.post(
url=f"{api_base}/api/show",
json={"name": model},
)
except Exception as e:
raise Exception(
f"OllamaError: Error getting model info for {model}. Set Ollama API Base via `OLLAMA_API_BASE` environment variable. Error: {e}"
)
model_info = response.json()
_max_tokens: Optional[int] = self._get_max_tokens(model_info)
return ModelInfo(
key=model,
litellm_provider="ollama",
mode="chat",
supported_openai_params=self.get_supported_openai_params(),
supports_function_calling=self._supports_function_calling(model_info),
input_cost_per_token=0.0,
output_cost_per_token=0.0,
max_tokens=_max_tokens,
max_input_tokens=_max_tokens,
max_output_tokens=_max_tokens,
)
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary. # and convert to jpeg if necessary.

View file

@ -2,4 +2,4 @@ model_list:
- model_name: "gpt-4o-audio-preview" - model_name: "gpt-4o-audio-preview"
litellm_params: litellm_params:
model: gpt-4o-audio-preview model: gpt-4o-audio-preview
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY

View file

@ -335,6 +335,31 @@ class LiteLLMRoutes(enum.Enum):
"/metrics", "/metrics",
] ]
ui_routes = [
"/sso",
"/sso/get/ui_settings",
"/login",
"/key/generate",
"/key/update",
"/key/info",
"/key/delete",
"/config",
"/spend",
"/user",
"/model/info",
"/v2/model/info",
"/v2/key/info",
"/models",
"/v1/models",
"/global/spend",
"/global/spend/logs",
"/global/spend/keys",
"/global/spend/models",
"/global/predict/spend/logs",
"/global/activity",
"/health/services",
] + info_routes
internal_user_routes = ( internal_user_routes = (
[ [
"/key/generate", "/key/generate",

View file

@ -105,6 +105,88 @@ def _get_bearer_token(
return api_key return api_key
def _is_ui_route_allowed(
route: str,
user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
"""
- Route b/w ui token check and normal token check
"""
# this token is only used for managing the ui
allowed_routes = LiteLLMRoutes.ui_routes.value
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
):
# Do something if the current route starts with any of the allowed routes
return True
else:
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return True
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
return True
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
def _is_api_route_allowed(
route: str,
request: Request,
request_data: dict,
api_key: str,
valid_token: Optional[UserAPIKeyAuth],
user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
"""
- Route b/w api token check and normal token check
"""
_user_role = _get_user_role(user_obj=user_obj)
if valid_token is None:
raise Exception("Invalid proxy server token passed")
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
non_proxy_admin_allowed_routes_check(
user_obj=user_obj,
_user_role=_user_role,
route=route,
request=request,
request_data=request_data,
api_key=api_key,
valid_token=valid_token,
)
return True
def _is_allowed_route(
route: str,
token_type: Literal["ui", "api"],
request: Request,
request_data: dict,
api_key: str,
valid_token: Optional[UserAPIKeyAuth],
user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
"""
- Route b/w ui token check and normal token check
"""
if token_type == "ui":
return _is_ui_route_allowed(route=route, user_obj=user_obj)
else:
return _is_api_route_allowed(
route=route,
request=request,
request_data=request_data,
api_key=api_key,
valid_token=valid_token,
user_obj=user_obj,
)
async def user_api_key_auth( # noqa: PLR0915 async def user_api_key_auth( # noqa: PLR0915
request: Request, request: Request,
api_key: str = fastapi.Security(api_key_header), api_key: str = fastapi.Security(api_key_header),
@ -1041,81 +1123,27 @@ async def user_api_key_auth( # noqa: PLR0915
if _end_user_object is not None: if _end_user_object is not None:
valid_token_dict.update(end_user_params) valid_token_dict.update(end_user_params)
_user_role = _get_user_role(user_obj=user_obj)
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
non_proxy_admin_allowed_routes_check(
user_obj=user_obj,
_user_role=_user_role,
route=route,
request=request,
request_data=request_data,
api_key=api_key,
valid_token=valid_token,
)
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
# sso/login, ui/login, /key functions and /user functions # sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions # this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None) token_team = getattr(valid_token, "team_id", None)
token_type: Literal["ui", "api"] = (
"ui"
if token_team is not None and token_team == "litellm-dashboard"
else "api"
)
_is_route_allowed = _is_allowed_route(
route=route,
token_type=token_type,
user_obj=user_obj,
request=request,
request_data=request_data,
api_key=api_key,
valid_token=valid_token,
)
if not _is_route_allowed:
raise HTTPException(401, detail="Invalid route for UI token")
if token_team is not None and token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/sso/get/ui_settings",
"/login",
"/key/generate",
"/key/update",
"/key/info",
"/config",
"/spend",
"/user",
"/model/info",
"/v2/model/info",
"/v2/key/info",
"/models",
"/v1/models",
"/global/spend",
"/global/spend/logs",
"/global/spend/keys",
"/global/spend/models",
"/global/predict/spend/logs",
"/global/activity",
"/health/services",
] + LiteLLMRoutes.info_routes.value # type: ignore
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(
route.startswith(allowed_route) for allowed_route in allowed_routes
)
):
# Do something if the current route starts with any of the allowed routes
pass
else:
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
return UserAPIKeyAuth(
api_key=api_key,
user_role=LitellmUserRoles.PROXY_ADMIN,
parent_otel_span=parent_otel_span,
**valid_token_dict,
)
elif (
_has_user_setup_sso()
and route in LiteLLMRoutes.sso_only_routes.value
):
return UserAPIKeyAuth(
api_key=api_key,
user_role=_user_role, # type: ignore
parent_otel_span=parent_otel_span,
**valid_token_dict,
)
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
if valid_token is None: if valid_token is None:
# No token was found when looking up in the DB # No token was found when looking up in the DB
raise Exception("Invalid proxy server token passed") raise Exception("Invalid proxy server token passed")

View file

@ -41,6 +41,40 @@ from litellm.proxy.management_helpers.utils import (
router = APIRouter() router = APIRouter()
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True)
if auto_create_key is False:
data_json["table_name"] = (
"user" # only create a user, don't create key if 'auto_create_key' set to False
)
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if litellm.default_internal_user_params:
for key, value in litellm.default_internal_user_params.items():
if key not in data_json or data_json[key] is None:
data_json[key] = value
elif (
key == "models"
and isinstance(data_json[key], list)
and len(data_json[key]) == 0
):
data_json[key] = value
if "max_budget" in data_json and data_json["max_budget"] is None:
if is_internal_user and litellm.max_internal_user_budget is not None:
data_json["max_budget"] = litellm.max_internal_user_budget
if "budget_duration" in data_json and data_json["budget_duration"] is None:
if is_internal_user and litellm.internal_user_budget_duration is not None:
data_json["budget_duration"] = litellm.internal_user_budget_duration
return data_json
@router.post( @router.post(
"/user/new", "/user/new",
tags=["Internal User management"], tags=["Internal User management"],
@ -94,26 +128,7 @@ async def new_user(
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
if "user_id" in data_json and data_json["user_id"] is None: data_json = _update_internal_user_params(data_json, data)
data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True)
if auto_create_key is False:
data_json["table_name"] = (
"user" # only create a user, don't create key if 'auto_create_key' set to False
)
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "max_budget" in data_json and data_json["max_budget"] is None:
if is_internal_user and litellm.max_internal_user_budget is not None:
data_json["max_budget"] = litellm.max_internal_user_budget
if "budget_duration" in data_json and data_json["budget_duration"] is None:
if is_internal_user and litellm.internal_user_budget_duration is not None:
data_json["budget_duration"] = litellm.internal_user_budget_duration
response = await generate_key_helper_fn(request_type="user", **data_json) response = await generate_key_helper_fn(request_type="user", **data_json)
# Admin UI Logic # Admin UI Logic

View file

@ -1585,10 +1585,6 @@ class ProxyConfig:
printed_yaml = copy.deepcopy(config) printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None) printed_yaml.pop("environment_variables", None)
verbose_proxy_logger.debug(
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
)
config = self._check_for_os_environ_vars(config=config) config = self._check_for_os_environ_vars(config=config)
return config return config

View file

@ -40,6 +40,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route, create_pass_through_route,
) )
from litellm.secret_managers.main import get_secret_str
router = APIRouter() router = APIRouter()
default_vertex_config = None default_vertex_config = None
@ -226,3 +227,53 @@ async def bedrock_proxy_route(
) )
return received_value return received_value
@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def azure_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
if base_target_url is None:
raise Exception(
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
)
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={
"authorization": "Bearer {}".format(azure_api_key),
"api-key": "{}".format(azure_api_key),
},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
query_params=dict(request.query_params), # type: ignore
)
return received_value

View file

@ -0,0 +1,21 @@
import os
from datetime import datetime as dt
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict
class ArgillaItem(TypedDict):
fields: Dict[str, Any]
class ArgillaPayload(TypedDict):
items: List[ArgillaItem]
class ArgillaCredentialsObject(TypedDict):
ARGILLA_API_KEY: str
ARGILLA_DATASET_NAME: str
ARGILLA_BASE_URL: str
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]

View file

@ -1821,6 +1821,7 @@ def supports_function_calling(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
## CHECK IF MODEL SUPPORTS FUNCTION CALLING ##
model_info = litellm.get_model_info( model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
@ -4768,6 +4769,8 @@ def get_model_info( # noqa: PLR0915
supports_assistant_prefill=None, supports_assistant_prefill=None,
supports_prompt_caching=None, supports_prompt_caching=None,
) )
elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
return litellm.OllamaConfig().get_model_info(model)
else: else:
""" """
Check if: (in order of specificity) Check if: (in order of specificity)
@ -4964,7 +4967,9 @@ def get_model_info( # noqa: PLR0915
supports_audio_input=_model_info.get("supports_audio_input", False), supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False), supports_audio_output=_model_info.get("supports_audio_output", False),
) )
except Exception: except Exception as e:
if "OllamaError" in str(e):
raise e
raise Exception( raise Exception(
"This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format( "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
model, custom_llm_provider model, custom_llm_provider

View file

@ -11,6 +11,7 @@ import pytest
import litellm import litellm
from litellm import get_model_info from litellm import get_model_info
from unittest.mock import AsyncMock, MagicMock, patch
def test_get_model_info_simple_model_name(): def test_get_model_info_simple_model_name():
@ -74,3 +75,25 @@ def test_get_model_info_gemini_pro():
info = litellm.get_model_info("gemini-1.5-pro-002") info = litellm.get_model_info("gemini-1.5-pro-002")
print("info", info) print("info", info)
assert info["key"] == "gemini-1.5-pro-002" assert info["key"] == "gemini-1.5-pro-002"
def test_get_model_info_ollama_chat():
from litellm.llms.ollama import OllamaConfig
with patch.object(
litellm.module_level_client,
"post",
return_value=MagicMock(
json=lambda: {
"model_info": {"llama.context_length": 32768},
"template": "tools",
}
),
):
info = OllamaConfig().get_model_info("mistral")
print("info", info)
assert info["supports_function_calling"] is True
info = get_model_info("ollama/mistral")
print("info", info)
assert info["supports_function_calling"] is True

View file

@ -406,3 +406,29 @@ def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
data = add_litellm_data_for_backend_llm_call(headers) data = add_litellm_data_for_backend_llm_call(headers)
assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True) assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
def test_update_internal_user_params():
from litellm.proxy.management_endpoints.internal_user_endpoints import (
_update_internal_user_params,
)
from litellm.proxy._types import NewUserRequest
litellm.default_internal_user_params = {
"max_budget": 100,
"budget_duration": "30d",
"models": ["gpt-3.5-turbo"],
}
data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
data_json = data.model_dump()
updated_data_json = _update_internal_user_params(data_json, data)
assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
assert (
updated_data_json["max_budget"]
== litellm.default_internal_user_params["max_budget"]
)
assert (
updated_data_json["budget_duration"]
== litellm.default_internal_user_params["budget_duration"]
)

View file

@ -291,3 +291,28 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
await user_api_key_auth(request=request, api_key="Bearer " + user_key) await user_api_key_auth(request=request, api_key="Bearer " + user_key)
setattr(proxy_server, "general_settings", initial_general_settings) setattr(proxy_server, "general_settings", initial_general_settings)
@pytest.mark.parametrize("route", ["/global/spend/logs", "/key/delete"])
def test_is_ui_route_allowed(route):
from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed
from litellm.proxy._types import LiteLLM_UserTable
received_args: dict = {
"route": route,
"user_obj": LiteLLM_UserTable(
user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
max_budget=None,
spend=0.0,
model_max_budget={},
model_spend={},
user_email="my-test-email@1234.com",
models=[],
tpm_limit=None,
rpm_limit=None,
user_role="internal_user",
organization_memberships=[],
),
}
assert _is_ui_route_allowed(**received_args)

View file

@ -448,24 +448,19 @@ def test_token_counter():
# test_token_counter() # test_token_counter()
def test_supports_function_calling(): @pytest.mark.parametrize(
"model, expected_bool",
[
("gpt-3.5-turbo", True),
("azure/gpt-4-1106-preview", True),
("groq/gemma-7b-it", True),
("anthropic.claude-instant-v1", False),
("palm/chat-bison", False),
],
)
def test_supports_function_calling(model, expected_bool):
try: try:
assert litellm.supports_function_calling(model="gpt-3.5-turbo") == True assert litellm.supports_function_calling(model=model) == expected_bool
assert (
litellm.supports_function_calling(model="azure/gpt-4-1106-preview") == True
)
assert litellm.supports_function_calling(model="groq/gemma-7b-it") == True
assert (
litellm.supports_function_calling(model="anthropic.claude-instant-v1")
== False
)
assert litellm.supports_function_calling(model="palm/chat-bison") == False
assert litellm.supports_function_calling(model="ollama/llama2") == False
assert (
litellm.supports_function_calling(model="anthropic.claude-instant-v1")
== False
)
assert litellm.supports_function_calling(model="claude-2") == False
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")