forked from phoenix/litellm-mirror
feat(custom_logger.py): expose new async_dataset_hook
for modifying… (#6331)
* feat(custom_logger.py): expose new `async_dataset_hook` for modifying/rejecting argilla items before logging Allows user more control on what gets logged to argilla for annotations * feat(google_ai_studio_endpoints.py): add new `/azure/*` pass through route enables pass-through for azure provider * feat(utils.py): support checking ollama `/api/show` endpoint for retrieving ollama model info Fixes https://github.com/BerriAI/litellm/issues/6322 * fix(user_api_key_auth.py): add `/key/delete` to an allowed_ui_routes Fixes https://github.com/BerriAI/litellm/issues/6236 * fix(user_api_key_auth.py): remove type ignore * fix(user_api_key_auth.py): route ui vs. api token checks differently Fixes https://github.com/BerriAI/litellm/issues/6238 * feat(internal_user_endpoints.py): support setting models as a default internal user param Closes https://github.com/BerriAI/litellm/issues/6239 * fix(user_api_key_auth.py): fix exception string * fix(user_api_key_auth.py): fix error string * fix: fix test
This commit is contained in:
parent
7cc12bd5c6
commit
905ebeb924
16 changed files with 422 additions and 153 deletions
|
@ -207,6 +207,7 @@ litellm_settings:
|
|||
user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
|
||||
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
||||
|
||||
|
||||
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
||||
|
|
|
@ -21,53 +21,22 @@ from pydantic import BaseModel # type: ignore
|
|||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
|
||||
from litellm.types.integrations.argilla import (
|
||||
SUPPORTED_PAYLOAD_FIELDS,
|
||||
ArgillaCredentialsObject,
|
||||
ArgillaItem,
|
||||
ArgillaPayload,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class LangsmithInputs(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: Optional[List[Any]] = None
|
||||
stream: Optional[bool] = None
|
||||
call_type: Optional[str] = None
|
||||
litellm_call_id: Optional[str] = None
|
||||
completion_start_time: Optional[datetime] = None
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
custom_llm_provider: Optional[str] = None
|
||||
input: Optional[List[Any]] = None
|
||||
log_event_type: Optional[str] = None
|
||||
original_response: Optional[Any] = None
|
||||
response_cost: Optional[float] = None
|
||||
|
||||
# LiteLLM Virtual Key specific fields
|
||||
user_api_key: Optional[str] = None
|
||||
user_api_key_user_id: Optional[str] = None
|
||||
user_api_key_team_alias: Optional[str] = None
|
||||
|
||||
|
||||
class ArgillaItem(TypedDict):
|
||||
fields: Dict[str, Any]
|
||||
|
||||
|
||||
class ArgillaPayload(TypedDict):
|
||||
items: List[ArgillaItem]
|
||||
|
||||
|
||||
class ArgillaCredentialsObject(TypedDict):
|
||||
ARGILLA_API_KEY: str
|
||||
ARGILLA_DATASET_NAME: str
|
||||
ARGILLA_BASE_URL: str
|
||||
|
||||
|
||||
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
non_serializable_types = (
|
||||
types.CoroutineType,
|
||||
|
@ -215,7 +184,7 @@ class ArgillaLogger(CustomBatchLogger):
|
|||
|
||||
def _prepare_log_data(
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
) -> ArgillaItem:
|
||||
) -> Optional[ArgillaItem]:
|
||||
try:
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
|
@ -235,6 +204,7 @@ class ArgillaLogger(CustomBatchLogger):
|
|||
argilla_item["fields"][k] = argilla_response
|
||||
else:
|
||||
argilla_item["fields"][k] = payload.get(v, None)
|
||||
|
||||
return argilla_item
|
||||
except Exception:
|
||||
raise
|
||||
|
@ -294,6 +264,9 @@ class ArgillaLogger(CustomBatchLogger):
|
|||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
|
@ -321,7 +294,25 @@ class ArgillaLogger(CustomBatchLogger):
|
|||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
|
||||
## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger):
|
||||
try:
|
||||
if data is None:
|
||||
break
|
||||
data = await callback.async_dataset_hook(data, payload)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
if data is None:
|
||||
return
|
||||
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
|
|
|
@ -10,6 +10,7 @@ from pydantic import BaseModel
|
|||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.integrations.argilla import ArgillaItem
|
||||
from litellm.types.llms.openai import ChatCompletionRequest
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
from litellm.types.utils import (
|
||||
|
@ -17,6 +18,7 @@ from litellm.types.utils import (
|
|||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
|
||||
|
@ -108,6 +110,20 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
"""
|
||||
pass
|
||||
|
||||
### DATASET HOOKS #### - currently only used for Argilla
|
||||
|
||||
async def async_dataset_hook(
|
||||
self,
|
||||
logged_item: ArgillaItem,
|
||||
standard_logging_payload: Optional[StandardLoggingPayload],
|
||||
) -> Optional[ArgillaItem]:
|
||||
"""
|
||||
- Decide if the result should be logged to Argilla.
|
||||
- Modify the result before logging to Argilla.
|
||||
- Return None if the result should not be logged to Argilla.
|
||||
"""
|
||||
raise NotImplementedError("async_dataset_hook not implemented")
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
"""
|
||||
Control the modify incoming / outgoung data before calling the model
|
||||
|
|
|
@ -14,7 +14,8 @@ import requests # type: ignore
|
|||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.utils import ProviderField, StreamingChoices
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices
|
||||
|
||||
from .prompt_templates.factory import custom_prompt, prompt_factory
|
||||
|
||||
|
@ -163,6 +164,56 @@ class OllamaConfig:
|
|||
"response_format",
|
||||
]
|
||||
|
||||
def _supports_function_calling(self, ollama_model_info: dict) -> bool:
|
||||
"""
|
||||
Check if the 'template' field in the ollama_model_info contains a 'tools' or 'function' key.
|
||||
"""
|
||||
_template: str = str(ollama_model_info.get("template", "") or "")
|
||||
return "tools" in _template.lower()
|
||||
|
||||
def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]:
|
||||
_model_info: dict = ollama_model_info.get("model_info", {})
|
||||
|
||||
for k, v in _model_info.items():
|
||||
if "context_length" in k:
|
||||
return v
|
||||
return None
|
||||
|
||||
def get_model_info(self, model: str) -> ModelInfo:
|
||||
"""
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
"name": "mistral"
|
||||
}'
|
||||
"""
|
||||
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
|
||||
|
||||
try:
|
||||
response = litellm.module_level_client.post(
|
||||
url=f"{api_base}/api/show",
|
||||
json={"name": model},
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"OllamaError: Error getting model info for {model}. Set Ollama API Base via `OLLAMA_API_BASE` environment variable. Error: {e}"
|
||||
)
|
||||
|
||||
model_info = response.json()
|
||||
|
||||
_max_tokens: Optional[int] = self._get_max_tokens(model_info)
|
||||
|
||||
return ModelInfo(
|
||||
key=model,
|
||||
litellm_provider="ollama",
|
||||
mode="chat",
|
||||
supported_openai_params=self.get_supported_openai_params(),
|
||||
supports_function_calling=self._supports_function_calling(model_info),
|
||||
input_cost_per_token=0.0,
|
||||
output_cost_per_token=0.0,
|
||||
max_tokens=_max_tokens,
|
||||
max_input_tokens=_max_tokens,
|
||||
max_output_tokens=_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
|
||||
# and convert to jpeg if necessary.
|
||||
|
|
|
@ -335,6 +335,31 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/metrics",
|
||||
]
|
||||
|
||||
ui_routes = [
|
||||
"/sso",
|
||||
"/sso/get/ui_settings",
|
||||
"/login",
|
||||
"/key/generate",
|
||||
"/key/update",
|
||||
"/key/info",
|
||||
"/key/delete",
|
||||
"/config",
|
||||
"/spend",
|
||||
"/user",
|
||||
"/model/info",
|
||||
"/v2/model/info",
|
||||
"/v2/key/info",
|
||||
"/models",
|
||||
"/v1/models",
|
||||
"/global/spend",
|
||||
"/global/spend/logs",
|
||||
"/global/spend/keys",
|
||||
"/global/spend/models",
|
||||
"/global/predict/spend/logs",
|
||||
"/global/activity",
|
||||
"/health/services",
|
||||
] + info_routes
|
||||
|
||||
internal_user_routes = (
|
||||
[
|
||||
"/key/generate",
|
||||
|
|
|
@ -105,6 +105,88 @@ def _get_bearer_token(
|
|||
return api_key
|
||||
|
||||
|
||||
def _is_ui_route_allowed(
|
||||
route: str,
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w ui token check and normal token check
|
||||
"""
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = LiteLLMRoutes.ui_routes.value
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(route.startswith(allowed_route) for allowed_route in allowed_routes)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
return True
|
||||
else:
|
||||
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
||||
return True
|
||||
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
|
||||
return True
|
||||
else:
|
||||
raise Exception(
|
||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||
)
|
||||
|
||||
|
||||
def _is_api_route_allowed(
|
||||
route: str,
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
api_key: str,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w api token check and normal token check
|
||||
"""
|
||||
_user_role = _get_user_role(user_obj=user_obj)
|
||||
|
||||
if valid_token is None:
|
||||
raise Exception("Invalid proxy server token passed")
|
||||
|
||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
non_proxy_admin_allowed_routes_check(
|
||||
user_obj=user_obj,
|
||||
_user_role=_user_role,
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _is_allowed_route(
|
||||
route: str,
|
||||
token_type: Literal["ui", "api"],
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
api_key: str,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
user_obj: Optional[LiteLLM_UserTable] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
- Route b/w ui token check and normal token check
|
||||
"""
|
||||
if token_type == "ui":
|
||||
return _is_ui_route_allowed(route=route, user_obj=user_obj)
|
||||
else:
|
||||
return _is_api_route_allowed(
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
user_obj=user_obj,
|
||||
)
|
||||
|
||||
|
||||
async def user_api_key_auth( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str = fastapi.Security(api_key_header),
|
||||
|
@ -1041,81 +1123,27 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
if _end_user_object is not None:
|
||||
valid_token_dict.update(end_user_params)
|
||||
|
||||
_user_role = _get_user_role(user_obj=user_obj)
|
||||
|
||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
non_proxy_admin_allowed_routes_check(
|
||||
user_obj=user_obj,
|
||||
_user_role=_user_role,
|
||||
route=route,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
# sso/login, ui/login, /key functions and /user functions
|
||||
# this will never be allowed to call /chat/completions
|
||||
token_team = getattr(valid_token, "team_id", None)
|
||||
token_type: Literal["ui", "api"] = (
|
||||
"ui"
|
||||
if token_team is not None and token_team == "litellm-dashboard"
|
||||
else "api"
|
||||
)
|
||||
_is_route_allowed = _is_allowed_route(
|
||||
route=route,
|
||||
token_type=token_type,
|
||||
user_obj=user_obj,
|
||||
request=request,
|
||||
request_data=request_data,
|
||||
api_key=api_key,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
if not _is_route_allowed:
|
||||
raise HTTPException(401, detail="Invalid route for UI token")
|
||||
|
||||
if token_team is not None and token_team == "litellm-dashboard":
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = [
|
||||
"/sso",
|
||||
"/sso/get/ui_settings",
|
||||
"/login",
|
||||
"/key/generate",
|
||||
"/key/update",
|
||||
"/key/info",
|
||||
"/config",
|
||||
"/spend",
|
||||
"/user",
|
||||
"/model/info",
|
||||
"/v2/model/info",
|
||||
"/v2/key/info",
|
||||
"/models",
|
||||
"/v1/models",
|
||||
"/global/spend",
|
||||
"/global/spend/logs",
|
||||
"/global/spend/keys",
|
||||
"/global/spend/models",
|
||||
"/global/predict/spend/logs",
|
||||
"/global/activity",
|
||||
"/health/services",
|
||||
] + LiteLLMRoutes.info_routes.value # type: ignore
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(
|
||||
route.startswith(allowed_route) for allowed_route in allowed_routes
|
||||
)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
pass
|
||||
else:
|
||||
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
parent_otel_span=parent_otel_span,
|
||||
**valid_token_dict,
|
||||
)
|
||||
elif (
|
||||
_has_user_setup_sso()
|
||||
and route in LiteLLMRoutes.sso_only_routes.value
|
||||
):
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role=_user_role, # type: ignore
|
||||
parent_otel_span=parent_otel_span,
|
||||
**valid_token_dict,
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||
)
|
||||
if valid_token is None:
|
||||
# No token was found when looking up in the DB
|
||||
raise Exception("Invalid proxy server token passed")
|
||||
|
|
|
@ -41,6 +41,40 @@ from litellm.proxy.management_helpers.utils import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
if "user_id" in data_json and data_json["user_id"] is None:
|
||||
data_json["user_id"] = str(uuid.uuid4())
|
||||
auto_create_key = data_json.pop("auto_create_key", True)
|
||||
if auto_create_key is False:
|
||||
data_json["table_name"] = (
|
||||
"user" # only create a user, don't create key if 'auto_create_key' set to False
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
if litellm.default_internal_user_params:
|
||||
for key, value in litellm.default_internal_user_params.items():
|
||||
if key not in data_json or data_json[key] is None:
|
||||
data_json[key] = value
|
||||
elif (
|
||||
key == "models"
|
||||
and isinstance(data_json[key], list)
|
||||
and len(data_json[key]) == 0
|
||||
):
|
||||
data_json[key] = value
|
||||
|
||||
if "max_budget" in data_json and data_json["max_budget"] is None:
|
||||
if is_internal_user and litellm.max_internal_user_budget is not None:
|
||||
data_json["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if "budget_duration" in data_json and data_json["budget_duration"] is None:
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
data_json["budget_duration"] = litellm.internal_user_budget_duration
|
||||
|
||||
return data_json
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user/new",
|
||||
tags=["Internal User management"],
|
||||
|
@ -94,26 +128,7 @@ async def new_user(
|
|||
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
if "user_id" in data_json and data_json["user_id"] is None:
|
||||
data_json["user_id"] = str(uuid.uuid4())
|
||||
auto_create_key = data_json.pop("auto_create_key", True)
|
||||
if auto_create_key is False:
|
||||
data_json["table_name"] = (
|
||||
"user" # only create a user, don't create key if 'auto_create_key' set to False
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
if "max_budget" in data_json and data_json["max_budget"] is None:
|
||||
if is_internal_user and litellm.max_internal_user_budget is not None:
|
||||
data_json["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if "budget_duration" in data_json and data_json["budget_duration"] is None:
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
data_json["budget_duration"] = litellm.internal_user_budget_duration
|
||||
|
||||
data_json = _update_internal_user_params(data_json, data)
|
||||
response = await generate_key_helper_fn(request_type="user", **data_json)
|
||||
|
||||
# Admin UI Logic
|
||||
|
|
|
@ -1585,10 +1585,6 @@ class ProxyConfig:
|
|||
printed_yaml = copy.deepcopy(config)
|
||||
printed_yaml.pop("environment_variables", None)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
|
||||
)
|
||||
|
||||
config = self._check_for_os_environ_vars(config=config)
|
||||
|
||||
return config
|
||||
|
|
|
@ -40,6 +40,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
create_pass_through_route,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
router = APIRouter()
|
||||
default_vertex_config = None
|
||||
|
@ -226,3 +227,53 @@ async def bedrock_proxy_route(
|
|||
)
|
||||
|
||||
return received_value
|
||||
|
||||
|
||||
@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def azure_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
|
||||
if base_target_url is None:
|
||||
raise Exception(
|
||||
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
|
||||
)
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
||||
# Construct the full target URL using httpx
|
||||
base_url = httpx.URL(base_target_url)
|
||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
if "stream" in str(updated_url):
|
||||
is_streaming_request = True
|
||||
|
||||
## CREATE PASS-THROUGH
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=endpoint,
|
||||
target=str(updated_url),
|
||||
custom_headers={
|
||||
"authorization": "Bearer {}".format(azure_api_key),
|
||||
"api-key": "{}".format(azure_api_key),
|
||||
},
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
received_value = await endpoint_func(
|
||||
request,
|
||||
fastapi_response,
|
||||
user_api_key_dict,
|
||||
stream=is_streaming_request, # type: ignore
|
||||
query_params=dict(request.query_params), # type: ignore
|
||||
)
|
||||
|
||||
return received_value
|
||||
|
|
21
litellm/types/integrations/argilla.py
Normal file
21
litellm/types/integrations/argilla.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
from datetime import datetime as dt
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict
|
||||
|
||||
|
||||
class ArgillaItem(TypedDict):
|
||||
fields: Dict[str, Any]
|
||||
|
||||
|
||||
class ArgillaPayload(TypedDict):
|
||||
items: List[ArgillaItem]
|
||||
|
||||
|
||||
class ArgillaCredentialsObject(TypedDict):
|
||||
ARGILLA_API_KEY: str
|
||||
ARGILLA_DATASET_NAME: str
|
||||
ARGILLA_BASE_URL: str
|
||||
|
||||
|
||||
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
|
|
@ -1821,6 +1821,7 @@ def supports_function_calling(
|
|||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
## CHECK IF MODEL SUPPORTS FUNCTION CALLING ##
|
||||
model_info = litellm.get_model_info(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
@ -4768,6 +4769,8 @@ def get_model_info( # noqa: PLR0915
|
|||
supports_assistant_prefill=None,
|
||||
supports_prompt_caching=None,
|
||||
)
|
||||
elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
|
||||
return litellm.OllamaConfig().get_model_info(model)
|
||||
else:
|
||||
"""
|
||||
Check if: (in order of specificity)
|
||||
|
@ -4964,7 +4967,9 @@ def get_model_info( # noqa: PLR0915
|
|||
supports_audio_input=_model_info.get("supports_audio_input", False),
|
||||
supports_audio_output=_model_info.get("supports_audio_output", False),
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
if "OllamaError" in str(e):
|
||||
raise e
|
||||
raise Exception(
|
||||
"This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
|
||||
model, custom_llm_provider
|
||||
|
|
|
@ -11,6 +11,7 @@ import pytest
|
|||
|
||||
import litellm
|
||||
from litellm import get_model_info
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def test_get_model_info_simple_model_name():
|
||||
|
@ -74,3 +75,25 @@ def test_get_model_info_gemini_pro():
|
|||
info = litellm.get_model_info("gemini-1.5-pro-002")
|
||||
print("info", info)
|
||||
assert info["key"] == "gemini-1.5-pro-002"
|
||||
|
||||
|
||||
def test_get_model_info_ollama_chat():
|
||||
from litellm.llms.ollama import OllamaConfig
|
||||
|
||||
with patch.object(
|
||||
litellm.module_level_client,
|
||||
"post",
|
||||
return_value=MagicMock(
|
||||
json=lambda: {
|
||||
"model_info": {"llama.context_length": 32768},
|
||||
"template": "tools",
|
||||
}
|
||||
),
|
||||
):
|
||||
info = OllamaConfig().get_model_info("mistral")
|
||||
print("info", info)
|
||||
assert info["supports_function_calling"] is True
|
||||
|
||||
info = get_model_info("ollama/mistral")
|
||||
print("info", info)
|
||||
assert info["supports_function_calling"] is True
|
||||
|
|
|
@ -406,3 +406,29 @@ def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
|
|||
data = add_litellm_data_for_backend_llm_call(headers)
|
||||
|
||||
assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
|
||||
|
||||
|
||||
def test_update_internal_user_params():
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
_update_internal_user_params,
|
||||
)
|
||||
from litellm.proxy._types import NewUserRequest
|
||||
|
||||
litellm.default_internal_user_params = {
|
||||
"max_budget": 100,
|
||||
"budget_duration": "30d",
|
||||
"models": ["gpt-3.5-turbo"],
|
||||
}
|
||||
|
||||
data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
|
||||
data_json = data.model_dump()
|
||||
updated_data_json = _update_internal_user_params(data_json, data)
|
||||
assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
|
||||
assert (
|
||||
updated_data_json["max_budget"]
|
||||
== litellm.default_internal_user_params["max_budget"]
|
||||
)
|
||||
assert (
|
||||
updated_data_json["budget_duration"]
|
||||
== litellm.default_internal_user_params["budget_duration"]
|
||||
)
|
||||
|
|
|
@ -291,3 +291,28 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
|
|||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
|
||||
setattr(proxy_server, "general_settings", initial_general_settings)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route", ["/global/spend/logs", "/key/delete"])
|
||||
def test_is_ui_route_allowed(route):
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed
|
||||
from litellm.proxy._types import LiteLLM_UserTable
|
||||
|
||||
received_args: dict = {
|
||||
"route": route,
|
||||
"user_obj": LiteLLM_UserTable(
|
||||
user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
|
||||
max_budget=None,
|
||||
spend=0.0,
|
||||
model_max_budget={},
|
||||
model_spend={},
|
||||
user_email="my-test-email@1234.com",
|
||||
models=[],
|
||||
tpm_limit=None,
|
||||
rpm_limit=None,
|
||||
user_role="internal_user",
|
||||
organization_memberships=[],
|
||||
),
|
||||
}
|
||||
|
||||
assert _is_ui_route_allowed(**received_args)
|
||||
|
|
|
@ -448,24 +448,19 @@ def test_token_counter():
|
|||
# test_token_counter()
|
||||
|
||||
|
||||
def test_supports_function_calling():
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_bool",
|
||||
[
|
||||
("gpt-3.5-turbo", True),
|
||||
("azure/gpt-4-1106-preview", True),
|
||||
("groq/gemma-7b-it", True),
|
||||
("anthropic.claude-instant-v1", False),
|
||||
("palm/chat-bison", False),
|
||||
],
|
||||
)
|
||||
def test_supports_function_calling(model, expected_bool):
|
||||
try:
|
||||
assert litellm.supports_function_calling(model="gpt-3.5-turbo") == True
|
||||
assert (
|
||||
litellm.supports_function_calling(model="azure/gpt-4-1106-preview") == True
|
||||
)
|
||||
assert litellm.supports_function_calling(model="groq/gemma-7b-it") == True
|
||||
assert (
|
||||
litellm.supports_function_calling(model="anthropic.claude-instant-v1")
|
||||
== False
|
||||
)
|
||||
assert litellm.supports_function_calling(model="palm/chat-bison") == False
|
||||
assert litellm.supports_function_calling(model="ollama/llama2") == False
|
||||
assert (
|
||||
litellm.supports_function_calling(model="anthropic.claude-instant-v1")
|
||||
== False
|
||||
)
|
||||
assert litellm.supports_function_calling(model="claude-2") == False
|
||||
assert litellm.supports_function_calling(model=model) == expected_bool
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue