forked from phoenix/litellm-mirror
feat(custom_logger.py): expose new async_dataset_hook for modifying… (#6331)

* feat(custom_logger.py): expose new `async_dataset_hook` for modifying/rejecting Argilla items before logging. Gives the user more control over what gets logged to Argilla for annotations.
* feat(google_ai_studio_endpoints.py): add new `/azure/*` pass-through route. Enables pass-through for the Azure provider.
* feat(utils.py): support checking the Ollama `/api/show` endpoint for retrieving Ollama model info. Fixes https://github.com/BerriAI/litellm/issues/6322
* fix(user_api_key_auth.py): add `/key/delete` to allowed_ui_routes. Fixes https://github.com/BerriAI/litellm/issues/6236
* fix(user_api_key_auth.py): remove type ignore
* fix(user_api_key_auth.py): route UI vs. API token checks differently. Fixes https://github.com/BerriAI/litellm/issues/6238
* feat(internal_user_endpoints.py): support setting models as a default internal user param. Closes https://github.com/BerriAI/litellm/issues/6239
* fix(user_api_key_auth.py): fix exception string
* fix(user_api_key_auth.py): fix error string
* fix: fix test
Parent: 7cc12bd5c6
Commit: 905ebeb924
16 changed files with 422 additions and 153 deletions
@@ -207,6 +207,7 @@ litellm_settings:
     user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
     max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
     budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
+    models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user


 upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on

@@ -21,53 +21,22 @@ from pydantic import BaseModel  # type: ignore
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
     httpxSpecialProvider,
 )
 from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
+from litellm.types.integrations.argilla import (
+    SUPPORTED_PAYLOAD_FIELDS,
+    ArgillaCredentialsObject,
+    ArgillaItem,
+    ArgillaPayload,
+)
 from litellm.types.utils import StandardLoggingPayload


-class LangsmithInputs(BaseModel):
-    model: Optional[str] = None
-    messages: Optional[List[Any]] = None
-    stream: Optional[bool] = None
-    call_type: Optional[str] = None
-    litellm_call_id: Optional[str] = None
-    completion_start_time: Optional[datetime] = None
-    temperature: Optional[float] = None
-    max_tokens: Optional[int] = None
-    custom_llm_provider: Optional[str] = None
-    input: Optional[List[Any]] = None
-    log_event_type: Optional[str] = None
-    original_response: Optional[Any] = None
-    response_cost: Optional[float] = None
-
-    # LiteLLM Virtual Key specific fields
-    user_api_key: Optional[str] = None
-    user_api_key_user_id: Optional[str] = None
-    user_api_key_team_alias: Optional[str] = None
-
-
-class ArgillaItem(TypedDict):
-    fields: Dict[str, Any]
-
-
-class ArgillaPayload(TypedDict):
-    items: List[ArgillaItem]
-
-
-class ArgillaCredentialsObject(TypedDict):
-    ARGILLA_API_KEY: str
-    ARGILLA_DATASET_NAME: str
-    ARGILLA_BASE_URL: str
-
-
-SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
-
-
 def is_serializable(value):
     non_serializable_types = (
         types.CoroutineType,

@@ -215,7 +184,7 @@ class ArgillaLogger(CustomBatchLogger):

     def _prepare_log_data(
         self, kwargs, response_obj, start_time, end_time
-    ) -> ArgillaItem:
+    ) -> Optional[ArgillaItem]:
         try:
             # Ensure everything in the payload is converted to str
             payload: Optional[StandardLoggingPayload] = kwargs.get(

@@ -235,6 +204,7 @@ class ArgillaLogger(CustomBatchLogger):
                     argilla_item["fields"][k] = argilla_response
                 else:
                     argilla_item["fields"][k] = payload.get(v, None)
+
             return argilla_item
         except Exception:
             raise

@@ -294,6 +264,9 @@ class ArgillaLogger(CustomBatchLogger):
                 response_obj,
             )
             data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+            if data is None:
+                return
+
             self.log_queue.append(data)
             verbose_logger.debug(
                 f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."

@@ -321,7 +294,25 @@ class ArgillaLogger(CustomBatchLogger):
                 kwargs,
                 response_obj,
             )
+            payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+
             data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
+
+            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
+            for callback in litellm.callbacks:
+                if isinstance(callback, CustomLogger):
+                    try:
+                        if data is None:
+                            break
+                        data = await callback.async_dataset_hook(data, payload)
+                    except NotImplementedError:
+                        pass
+
+            if data is None:
+                return
+
             self.log_queue.append(data)
             verbose_logger.debug(
                 "Langsmith logging: queue length %s, batch size %s",

@@ -10,6 +10,7 @@ from pydantic import BaseModel

 from litellm.caching.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.integrations.argilla import ArgillaItem
 from litellm.types.llms.openai import ChatCompletionRequest
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import (

@@ -17,6 +18,7 @@ from litellm.types.utils import (
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    StandardLoggingPayload,
 )



@@ -108,6 +110,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         """
         pass

+    ### DATASET HOOKS #### - currently only used for Argilla
+
+    async def async_dataset_hook(
+        self,
+        logged_item: ArgillaItem,
+        standard_logging_payload: Optional[StandardLoggingPayload],
+    ) -> Optional[ArgillaItem]:
+        """
+        - Decide if the result should be logged to Argilla.
+        - Modify the result before logging to Argilla.
+        - Return None if the result should not be logged to Argilla.
+        """
+        raise NotImplementedError("async_dataset_hook not implemented")
+
     #### CALL HOOKS - proxy only ####
     """
     Control the modify incoming / outgoung data before calling the model
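
For illustration, a minimal sketch of a callback that uses the new hook. The class name and the rule of dropping items without a "response" field are examples only, not part of this commit; registering an instance in litellm.callbacks lets the Argilla logger loop above invoke it before queueing an item.

from typing import Optional

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.utils import StandardLoggingPayload


class MyArgillaFilter(CustomLogger):  # hypothetical example class
    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        # Skip items whose "response" field is empty; log everything else unchanged.
        if not logged_item["fields"].get("response"):
            return None  # returning None means "do not log this item to Argilla"
        return logged_item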

@@ -14,7 +14,8 @@ import requests  # type: ignore

 import litellm
 from litellm import verbose_logger
-from litellm.types.utils import ProviderField, StreamingChoices
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices

 from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -163,6 +164,56 @@ class OllamaConfig:
             "response_format",
         ]

+    def _supports_function_calling(self, ollama_model_info: dict) -> bool:
+        """
+        Check if the 'template' field in the ollama_model_info contains a 'tools' or 'function' key.
+        """
+        _template: str = str(ollama_model_info.get("template", "") or "")
+        return "tools" in _template.lower()
+
+    def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]:
+        _model_info: dict = ollama_model_info.get("model_info", {})
+
+        for k, v in _model_info.items():
+            if "context_length" in k:
+                return v
+        return None
+
+    def get_model_info(self, model: str) -> ModelInfo:
+        """
+        curl http://localhost:11434/api/show -d '{
+          "name": "mistral"
+        }'
+        """
+        api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
+
+        try:
+            response = litellm.module_level_client.post(
+                url=f"{api_base}/api/show",
+                json={"name": model},
+            )
+        except Exception as e:
+            raise Exception(
+                f"OllamaError: Error getting model info for {model}. Set Ollama API Base via `OLLAMA_API_BASE` environment variable. Error: {e}"
+            )
+
+        model_info = response.json()
+
+        _max_tokens: Optional[int] = self._get_max_tokens(model_info)
+
+        return ModelInfo(
+            key=model,
+            litellm_provider="ollama",
+            mode="chat",
+            supported_openai_params=self.get_supported_openai_params(),
+            supports_function_calling=self._supports_function_calling(model_info),
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+            max_tokens=_max_tokens,
+            max_input_tokens=_max_tokens,
+            max_output_tokens=_max_tokens,
+        )
+
+
 # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
 # and convert to jpeg if necessary.
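
For illustration, a quick way to exercise the new code path, following the test added later in this commit. It assumes an Ollama server is reachable at the default http://localhost:11434 (or via OLLAMA_API_BASE) and that the model has been pulled.

import litellm

# get_model_info now queries Ollama's /api/show endpoint for ollama/* models.
info = litellm.get_model_info("ollama/mistral")
print(info["max_tokens"], info["supports_function_calling"])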

@@ -2,4 +2,4 @@ model_list:
   - model_name: "gpt-4o-audio-preview"
     litellm_params:
       model: gpt-4o-audio-preview
       api_key: os.environ/OPENAI_API_KEY

@@ -335,6 +335,31 @@ class LiteLLMRoutes(enum.Enum):
         "/metrics",
     ]

+    ui_routes = [
+        "/sso",
+        "/sso/get/ui_settings",
+        "/login",
+        "/key/generate",
+        "/key/update",
+        "/key/info",
+        "/key/delete",
+        "/config",
+        "/spend",
+        "/user",
+        "/model/info",
+        "/v2/model/info",
+        "/v2/key/info",
+        "/models",
+        "/v1/models",
+        "/global/spend",
+        "/global/spend/logs",
+        "/global/spend/keys",
+        "/global/spend/models",
+        "/global/predict/spend/logs",
+        "/global/activity",
+        "/health/services",
+    ] + info_routes
+
     internal_user_routes = (
         [
             "/key/generate",
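
For illustration, a small sanity-check sketch (not part of this commit) showing that the newly added "/key/delete" entry is picked up by the prefix check used for UI tokens.

from litellm.proxy._types import LiteLLMRoutes

# "/key/delete" was added to ui_routes in this commit; UI tokens are checked with a
# startswith() match against this list.
assert any("/key/delete".startswith(r) for r in LiteLLMRoutes.ui_routes.value)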

@@ -105,6 +105,88 @@ def _get_bearer_token(
     return api_key


+def _is_ui_route_allowed(
+    route: str,
+    user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+    """
+    - Route b/w ui token check and normal token check
+    """
+    # this token is only used for managing the ui
+    allowed_routes = LiteLLMRoutes.ui_routes.value
+    # check if the current route startswith any of the allowed routes
+    if (
+        route is not None
+        and isinstance(route, str)
+        and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
+    ):
+        # Do something if the current route starts with any of the allowed routes
+        return True
+    else:
+        if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
+            return True
+        elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
+            return True
+        else:
+            raise Exception(
+                f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
+            )
+
+
+def _is_api_route_allowed(
+    route: str,
+    request: Request,
+    request_data: dict,
+    api_key: str,
+    valid_token: Optional[UserAPIKeyAuth],
+    user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+    """
+    - Route b/w api token check and normal token check
+    """
+    _user_role = _get_user_role(user_obj=user_obj)
+
+    if valid_token is None:
+        raise Exception("Invalid proxy server token passed")
+
+    if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
+        non_proxy_admin_allowed_routes_check(
+            user_obj=user_obj,
+            _user_role=_user_role,
+            route=route,
+            request=request,
+            request_data=request_data,
+            api_key=api_key,
+            valid_token=valid_token,
+        )
+    return True
+
+
+def _is_allowed_route(
+    route: str,
+    token_type: Literal["ui", "api"],
+    request: Request,
+    request_data: dict,
+    api_key: str,
+    valid_token: Optional[UserAPIKeyAuth],
+    user_obj: Optional[LiteLLM_UserTable] = None,
+) -> bool:
+    """
+    - Route b/w ui token check and normal token check
+    """
+    if token_type == "ui":
+        return _is_ui_route_allowed(route=route, user_obj=user_obj)
+    else:
+        return _is_api_route_allowed(
+            route=route,
+            request=request,
+            request_data=request_data,
+            api_key=api_key,
+            valid_token=valid_token,
+            user_obj=user_obj,
+        )
+
+
 async def user_api_key_auth(  # noqa: PLR0915
     request: Request,
     api_key: str = fastapi.Security(api_key_header),

@@ -1041,81 +1123,27 @@ async def user_api_key_auth(  # noqa: PLR0915
             if _end_user_object is not None:
                 valid_token_dict.update(end_user_params)

-            _user_role = _get_user_role(user_obj=user_obj)
-
-            if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
-                non_proxy_admin_allowed_routes_check(
-                    user_obj=user_obj,
-                    _user_role=_user_role,
-                    route=route,
-                    request=request,
-                    request_data=request_data,
-                    api_key=api_key,
-                    valid_token=valid_token,
-                )
-
             # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
             # sso/login, ui/login, /key functions and /user functions
             # this will never be allowed to call /chat/completions
             token_team = getattr(valid_token, "team_id", None)
+            token_type: Literal["ui", "api"] = (
+                "ui"
+                if token_team is not None and token_team == "litellm-dashboard"
+                else "api"
+            )
+            _is_route_allowed = _is_allowed_route(
+                route=route,
+                token_type=token_type,
+                user_obj=user_obj,
+                request=request,
+                request_data=request_data,
+                api_key=api_key,
+                valid_token=valid_token,
+            )
+            if not _is_route_allowed:
+                raise HTTPException(401, detail="Invalid route for UI token")

-            if token_team is not None and token_team == "litellm-dashboard":
-                # this token is only used for managing the ui
-                allowed_routes = [
-                    "/sso",
-                    "/sso/get/ui_settings",
-                    "/login",
-                    "/key/generate",
-                    "/key/update",
-                    "/key/info",
-                    "/config",
-                    "/spend",
-                    "/user",
-                    "/model/info",
-                    "/v2/model/info",
-                    "/v2/key/info",
-                    "/models",
-                    "/v1/models",
-                    "/global/spend",
-                    "/global/spend/logs",
-                    "/global/spend/keys",
-                    "/global/spend/models",
-                    "/global/predict/spend/logs",
-                    "/global/activity",
-                    "/health/services",
-                ] + LiteLLMRoutes.info_routes.value  # type: ignore
-                # check if the current route startswith any of the allowed routes
-                if (
-                    route is not None
-                    and isinstance(route, str)
-                    and any(
-                        route.startswith(allowed_route) for allowed_route in allowed_routes
-                    )
-                ):
-                    # Do something if the current route starts with any of the allowed routes
-                    pass
-                else:
-                    if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
-                        return UserAPIKeyAuth(
-                            api_key=api_key,
-                            user_role=LitellmUserRoles.PROXY_ADMIN,
-                            parent_otel_span=parent_otel_span,
-                            **valid_token_dict,
-                        )
-                    elif (
-                        _has_user_setup_sso()
-                        and route in LiteLLMRoutes.sso_only_routes.value
-                    ):
-                        return UserAPIKeyAuth(
-                            api_key=api_key,
-                            user_role=_user_role,  # type: ignore
-                            parent_otel_span=parent_otel_span,
-                            **valid_token_dict,
-                        )
-                    else:
-                        raise Exception(
-                            f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
-                        )
         if valid_token is None:
             # No token was found when looking up in the DB
             raise Exception("Invalid proxy server token passed")

@@ -41,6 +41,40 @@ from litellm.proxy.management_helpers.utils import (
 router = APIRouter()


+def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
+    if "user_id" in data_json and data_json["user_id"] is None:
+        data_json["user_id"] = str(uuid.uuid4())
+    auto_create_key = data_json.pop("auto_create_key", True)
+    if auto_create_key is False:
+        data_json["table_name"] = (
+            "user"  # only create a user, don't create key if 'auto_create_key' set to False
+        )
+
+    is_internal_user = False
+    if data.user_role == LitellmUserRoles.INTERNAL_USER:
+        is_internal_user = True
+        if litellm.default_internal_user_params:
+            for key, value in litellm.default_internal_user_params.items():
+                if key not in data_json or data_json[key] is None:
+                    data_json[key] = value
+                elif (
+                    key == "models"
+                    and isinstance(data_json[key], list)
+                    and len(data_json[key]) == 0
+                ):
+                    data_json[key] = value
+
+    if "max_budget" in data_json and data_json["max_budget"] is None:
+        if is_internal_user and litellm.max_internal_user_budget is not None:
+            data_json["max_budget"] = litellm.max_internal_user_budget
+
+    if "budget_duration" in data_json and data_json["budget_duration"] is None:
+        if is_internal_user and litellm.internal_user_budget_duration is not None:
+            data_json["budget_duration"] = litellm.internal_user_budget_duration
+
+    return data_json
+
+
 @router.post(
     "/user/new",
     tags=["Internal User management"],
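
For illustration, a short sketch of the new behaviour, mirroring the test added later in this commit; the default values below are placeholders.

import litellm
from litellm.proxy._types import NewUserRequest
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    _update_internal_user_params,
)

# Defaults applied to every new internal user (illustrative values).
litellm.default_internal_user_params = {"models": ["gpt-3.5-turbo"], "max_budget": 100}

data = NewUserRequest(user_role="internal_user", user_email="user@example.com")
data_json = _update_internal_user_params(data.model_dump(), data)
assert data_json["models"] == ["gpt-3.5-turbo"]
assert data_json["max_budget"] == 100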

@@ -94,26 +128,7 @@ async def new_user(
     from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

     data_json = data.json()  # type: ignore
-    if "user_id" in data_json and data_json["user_id"] is None:
-        data_json["user_id"] = str(uuid.uuid4())
-    auto_create_key = data_json.pop("auto_create_key", True)
-    if auto_create_key is False:
-        data_json["table_name"] = (
-            "user"  # only create a user, don't create key if 'auto_create_key' set to False
-        )
-
-    is_internal_user = False
-    if data.user_role == LitellmUserRoles.INTERNAL_USER:
-        is_internal_user = True
-
-    if "max_budget" in data_json and data_json["max_budget"] is None:
-        if is_internal_user and litellm.max_internal_user_budget is not None:
-            data_json["max_budget"] = litellm.max_internal_user_budget
-
-    if "budget_duration" in data_json and data_json["budget_duration"] is None:
-        if is_internal_user and litellm.internal_user_budget_duration is not None:
-            data_json["budget_duration"] = litellm.internal_user_budget_duration
-
+    data_json = _update_internal_user_params(data_json, data)
     response = await generate_key_helper_fn(request_type="user", **data_json)

     # Admin UI Logic

@@ -1585,10 +1585,6 @@ class ProxyConfig:
         printed_yaml = copy.deepcopy(config)
         printed_yaml.pop("environment_variables", None)

-        verbose_proxy_logger.debug(
-            f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
-        )
-
         config = self._check_for_os_environ_vars(config=config)

         return config

@@ -40,6 +40,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     create_pass_through_route,
 )
+from litellm.secret_managers.main import get_secret_str

 router = APIRouter()
 default_vertex_config = None

@@ -226,3 +227,53 @@ async def bedrock_proxy_route(
     )

     return received_value
+
+
+@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def azure_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
+    if base_target_url is None:
+        raise Exception(
+            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
+        )
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
+
+    ## check for streaming
+    is_streaming_request = False
+    if "stream" in str(updated_url):
+        is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers={
+            "authorization": "Bearer {}".format(azure_api_key),
+            "api-key": "{}".format(azure_api_key),
+        },
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+        query_params=dict(request.query_params),  # type: ignore
+    )
+
+    return received_value
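
For illustration, a hypothetical call through the new pass-through route. The proxy URL, virtual key, deployment name, and api-version are placeholders, not values from this commit.

import httpx

# Assumptions: a LiteLLM proxy runs at http://localhost:4000 with virtual key "sk-1234";
# the Azure deployment name and api-version below are placeholders.
response = httpx.post(
    "http://localhost:4000/azure/openai/deployments/my-deployment/chat/completions",
    params={"api-version": "2024-02-15-preview"},
    headers={"Authorization": "Bearer sk-1234"},
    json={"messages": [{"role": "user", "content": "Hello from the Azure pass-through"}]},
)
print(response.json())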

litellm/types/integrations/argilla.py (new file, 21 lines)

@@ -0,0 +1,21 @@
+import os
+from datetime import datetime as dt
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Set, TypedDict
+
+
+class ArgillaItem(TypedDict):
+    fields: Dict[str, Any]
+
+
+class ArgillaPayload(TypedDict):
+    items: List[ArgillaItem]
+
+
+class ArgillaCredentialsObject(TypedDict):
+    ARGILLA_API_KEY: str
+    ARGILLA_DATASET_NAME: str
+    ARGILLA_BASE_URL: str
+
+
+SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]

@@ -1821,6 +1821,7 @@ def supports_function_calling(
         model=model, custom_llm_provider=custom_llm_provider
     )

+    ## CHECK IF MODEL SUPPORTS FUNCTION CALLING ##
     model_info = litellm.get_model_info(
         model=model, custom_llm_provider=custom_llm_provider
     )

@@ -4768,6 +4769,8 @@ def get_model_info(  # noqa: PLR0915
             supports_assistant_prefill=None,
             supports_prompt_caching=None,
         )
+    elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
+        return litellm.OllamaConfig().get_model_info(model)
     else:
         """
         Check if: (in order of specificity)

@@ -4964,7 +4967,9 @@ def get_model_info(  # noqa: PLR0915
             supports_audio_input=_model_info.get("supports_audio_input", False),
             supports_audio_output=_model_info.get("supports_audio_output", False),
         )
-    except Exception:
+    except Exception as e:
+        if "OllamaError" in str(e):
+            raise e
         raise Exception(
             "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
                 model, custom_llm_provider

@@ -11,6 +11,7 @@ import pytest

 import litellm
 from litellm import get_model_info
+from unittest.mock import AsyncMock, MagicMock, patch


 def test_get_model_info_simple_model_name():

@@ -74,3 +75,25 @@ def test_get_model_info_gemini_pro():
     info = litellm.get_model_info("gemini-1.5-pro-002")
     print("info", info)
     assert info["key"] == "gemini-1.5-pro-002"
+
+
+def test_get_model_info_ollama_chat():
+    from litellm.llms.ollama import OllamaConfig
+
+    with patch.object(
+        litellm.module_level_client,
+        "post",
+        return_value=MagicMock(
+            json=lambda: {
+                "model_info": {"llama.context_length": 32768},
+                "template": "tools",
+            }
+        ),
+    ):
+        info = OllamaConfig().get_model_info("mistral")
+        print("info", info)
+        assert info["supports_function_calling"] is True
+
+        info = get_model_info("ollama/mistral")
+        print("info", info)
+        assert info["supports_function_calling"] is True

@@ -406,3 +406,29 @@ def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
     data = add_litellm_data_for_backend_llm_call(headers)

     assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
+
+
+def test_update_internal_user_params():
+    from litellm.proxy.management_endpoints.internal_user_endpoints import (
+        _update_internal_user_params,
+    )
+    from litellm.proxy._types import NewUserRequest
+
+    litellm.default_internal_user_params = {
+        "max_budget": 100,
+        "budget_duration": "30d",
+        "models": ["gpt-3.5-turbo"],
+    }
+
+    data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
+    data_json = data.model_dump()
+    updated_data_json = _update_internal_user_params(data_json, data)
+    assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
+    assert (
+        updated_data_json["max_budget"]
+        == litellm.default_internal_user_params["max_budget"]
+    )
+    assert (
+        updated_data_json["budget_duration"]
+        == litellm.default_internal_user_params["budget_duration"]
+    )

@@ -291,3 +291,28 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
         await user_api_key_auth(request=request, api_key="Bearer " + user_key)

     setattr(proxy_server, "general_settings", initial_general_settings)
+
+
+@pytest.mark.parametrize("route", ["/global/spend/logs", "/key/delete"])
+def test_is_ui_route_allowed(route):
+    from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed
+    from litellm.proxy._types import LiteLLM_UserTable
+
+    received_args: dict = {
+        "route": route,
+        "user_obj": LiteLLM_UserTable(
+            user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
+            max_budget=None,
+            spend=0.0,
+            model_max_budget={},
+            model_spend={},
+            user_email="my-test-email@1234.com",
+            models=[],
+            tpm_limit=None,
+            rpm_limit=None,
+            user_role="internal_user",
+            organization_memberships=[],
+        ),
+    }
+
+    assert _is_ui_route_allowed(**received_args)

@@ -448,24 +448,19 @@ def test_token_counter():
 # test_token_counter()


-def test_supports_function_calling():
+@pytest.mark.parametrize(
+    "model, expected_bool",
+    [
+        ("gpt-3.5-turbo", True),
+        ("azure/gpt-4-1106-preview", True),
+        ("groq/gemma-7b-it", True),
+        ("anthropic.claude-instant-v1", False),
+        ("palm/chat-bison", False),
+    ],
+)
+def test_supports_function_calling(model, expected_bool):
     try:
-        assert litellm.supports_function_calling(model="gpt-3.5-turbo") == True
-        assert (
-            litellm.supports_function_calling(model="azure/gpt-4-1106-preview") == True
-        )
-        assert litellm.supports_function_calling(model="groq/gemma-7b-it") == True
-        assert (
-            litellm.supports_function_calling(model="anthropic.claude-instant-v1")
-            == False
-        )
-        assert litellm.supports_function_calling(model="palm/chat-bison") == False
-        assert litellm.supports_function_calling(model="ollama/llama2") == False
-        assert (
-            litellm.supports_function_calling(model="anthropic.claude-instant-v1")
-            == False
-        )
-        assert litellm.supports_function_calling(model="claude-2") == False
+        assert litellm.supports_function_calling(model=model) == expected_bool
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")