LiteLLM Minor Fixes & Improvements (10/17/2024) (#6293)

* fix(ui_sso.py): fix faulty admin only check

Fixes https://github.com/BerriAI/litellm/issues/6286

* refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing

Prevent future regressions
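
A minimal sketch of how the new helpers compose (gate_admin_ui_access is a hypothetical wrapper, not part of this PR; the real check lives inside the /sso/callback handler):

from fastapi import HTTPException

from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.management_endpoints.sso_helper_utils import (
    check_is_admin_only_access,
    has_admin_ui_access,
)


def gate_admin_ui_access(ui_access_mode: str, user_role: str) -> None:
    # Only enforce the role check when the proxy UI is configured as admin_only.
    if check_is_admin_only_access(ui_access_mode) and not has_admin_ui_access(user_role):
        raise HTTPException(
            status_code=401,
            detail={"error": f"User not allowed to access proxy. User role={user_role}"},
        )


gate_admin_ui_access("admin_only", LitellmUserRoles.PROXY_ADMIN.value)  # allowed
gate_admin_ui_access("admin_only", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value)  # raises 401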

* feat(prompt_factory): support 'ensure_alternating_roles' param

Closes https://github.com/BerriAI/litellm/issues/6257
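
A minimal usage sketch of the new kwargs (model name and messages are placeholders; if the continue messages are omitted, a default "Please continue." user/assistant message is used):

import litellm

response = litellm.completion(
    model="databricks/databricks-meta-llama-3-1-70b-instruct",  # placeholder model
    messages=[
        {"role": "user", "content": "Hello!"},
        {"role": "user", "content": "What is Databricks?"},  # two consecutive user turns
    ],
    ensure_alternating_roles=True,
    # Optional overrides for the inserted placeholder messages.
    user_continue_message={"role": "user", "content": "Ok"},
    assistant_continue_message={"role": "assistant", "content": "Please continue"},
)
# An assistant continue message is inserted between the two consecutive user
# messages before the request is sent to the provider.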

* fix(proxy/utils.py): add dailytagspend to expected views

* feat(auth_utils.py): support setting regex for clientside auth credentials

Fixes https://github.com/BerriAI/litellm/issues/6203
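
A sketch based on the unit test added in this PR (model name and API key are dummies):

from litellm.proxy.auth.auth_utils import (
    _allow_model_level_clientside_configurable_parameters,
)
from litellm.router import Router

router = Router(
    model_list=[
        {
            "model_name": "dummy-model",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "dummy-key",
                # A dict entry allows a clientside api_base only when it matches
                # this regex (or is exactly equal to it).
                "configurable_clientside_auth_params": [
                    {"api_base": r"^https://litellm.*direct\.fireworks\.ai/v1$"}
                ],
            },
        }
    ]
)

allowed = _allow_model_level_clientside_configurable_parameters(
    model="dummy-model",
    param="api_base",
    request_body_value="https://litellm-dev.direct.fireworks.ai/v1",
    llm_router=router,
)
print(allowed)  # True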

* build(cookbook): add tutorial for mlflow + langchain + litellm proxy tracing
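
Condensed gist of the notebook added below (proxy URL and key are placeholders; assumes mlflow and langchain-openai are installed and a LiteLLM proxy is running):

import mlflow
from langchain_openai import ChatOpenAI

mlflow.langchain.autolog()  # traces every LangChain invocation to MLflow

# Point an OpenAI-compatible client at the LiteLLM proxy.
model = ChatOpenAI(
    openai_api_base="http://0.0.0.0:4000",  # LITELLM_PROXY_BASE_URL
    model="gpt-3.5-turbo",                  # LiteLLM 'model_name'
    api_key="sk-1234",                      # LITELLM_PROXY_API_KEY
    temperature=0.1,
)
model.invoke("what is rag?")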

* feat(argilla.py): add argilla logging integration

Closes https://github.com/BerriAI/litellm/issues/6201
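
A minimal SDK-side sketch of the new callback (env values and model are placeholders; the transformation object maps Argilla dataset fields to the supported payload fields "messages" and "response"):

import os

import litellm

os.environ["ARGILLA_API_KEY"] = "argilla.apikey"           # placeholder
os.environ["ARGILLA_BASE_URL"] = "http://localhost:6900"   # placeholder
os.environ["ARGILLA_DATASET_NAME"] = "litellm-completion"  # placeholder

# Map Argilla record fields -> logged payload fields ("messages" or "response").
litellm.argilla_transformation_object = {
    "user_input": "messages",
    "llm_output": "response",
}
litellm.callbacks = ["argilla"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
    mock_response="Hi there!",  # avoids a real provider call in this sketch
)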

* fix: fix linting errors

* fix: fix ruff error

* test: fix test

* fix: update vertex ai assumption - parts not always guaranteed (#6296)
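
The defensive pattern from that fix, as a standalone sketch (extract_candidate_text and the sample dict are illustrative; Gemini candidates may omit 'parts', e.g. for blocked responses):

def extract_candidate_text(candidate: dict) -> str:
    # 'parts' is not guaranteed to be present on candidate['content'].
    if (
        "content" in candidate
        and "parts" in candidate["content"]
        and "text" in candidate["content"]["parts"][0]
    ):
        return candidate["content"]["parts"][0]["text"]
    return ""


print(extract_candidate_text({"content": {"role": "model"}}))  # "" instead of a KeyError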

* docs(configs.md): add argilla env vars to docs
Krish Dholakia 2024-10-17 22:09:11 -07:00 committed by GitHub
parent 5e381caf75
commit f252350881
23 changed files with 1388 additions and 43 deletions


@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Databricks Notebook with MLFlow AutoLogging for LiteLLM Proxy calls\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "5e2812ed-8000-4793-b090-49a31464d810",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6 "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "52530b37-1860-4bba-a6c1-723de83bc58f",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"%pip install \"langchain-openai<=0.3.1\""
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "43c6f4b1-e2d5-431c-b1a2-b97df7707d59",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"# Before logging this chain using the driver notebook, you must comment out this line.\n",
"dbutils.library.restartPython() "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "88eb8dd7-16b1-480b-aa70-cd429ef87159",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"import mlflow\n",
"from operator import itemgetter\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_core.runnables import RunnableLambda\n",
"from langchain_databricks import ChatDatabricks\n",
"from langchain_openai import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "f0fdca8f-6f6f-407c-ad4a-0d5a2778728e",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"import mlflow\n",
"mlflow.langchain.autolog()"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "2ef67315-e468-4d60-a318-98c2cac75bc4",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"# These helper functions parse the `messages` array.\n",
"\n",
"# Return the string contents of the most recent message from the user\n",
"def extract_user_query_string(chat_messages_array):\n",
" return chat_messages_array[-1][\"content\"]\n",
"\n",
"\n",
"# Return the chat history, which is is everything before the last question\n",
"def extract_chat_history(chat_messages_array):\n",
" return chat_messages_array[:-1]"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "17708467-1976-48bd-94a0-8c7895cfae3b",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"model = ChatOpenAI(\n",
" openai_api_base=\"LITELLM_PROXY_BASE_URL\", # e.g.: http://0.0.0.0:4000\n",
" model = \"gpt-3.5-turbo\", # LITELLM 'model_name'\n",
" temperature=0.1, \n",
" api_key=\"LITELLM_PROXY_API_KEY\" # e.g.: \"sk-1234\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "a5f2c2af-82f7-470d-b559-47b67fb00cda",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"############\n",
"# Prompt Template for generation\n",
"############\n",
"prompt = PromptTemplate(\n",
" template=\"You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}\",\n",
" input_variables=[\"question\"],\n",
")\n",
"\n",
"############\n",
"# FM for generation\n",
"# ChatDatabricks accepts any /llm/v1/chat model serving endpoint\n",
"############\n",
"model = ChatDatabricks(\n",
" endpoint=\"databricks-dbrx-instruct\",\n",
" extra_params={\"temperature\": 0.01, \"max_tokens\": 500},\n",
")\n",
"\n",
"\n",
"############\n",
"# Simple chain\n",
"############\n",
"# The framework requires the chain to return a string value.\n",
"chain = (\n",
" {\n",
" \"question\": itemgetter(\"messages\")\n",
" | RunnableLambda(extract_user_query_string),\n",
" \"chat_history\": itemgetter(\"messages\") | RunnableLambda(extract_chat_history),\n",
" }\n",
" | prompt\n",
" | model\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "366edd90-62a1-4d6f-8a65-0211fb24ca02",
"showTitle": false,
"title": ""
}
},
"outputs": [
{
"data": {
"text/plain": [
"'Hello there! I\\'m here to help with your questions. Regarding your query about \"rag,\" it\\'s not something typically associated with a \"hello world\" bot, but I\\'m happy to explain!\\n\\nRAG, or Remote Angular GUI, is a tool that allows you to create and manage Angular applications remotely. It\\'s a way to develop and test Angular components and applications without needing to set up a local development environment. This can be particularly useful for teams working on distributed systems or for developers who prefer to work in a cloud-based environment.\\n\\nI hope this explanation of RAG has been helpful and interesting! If you have any other questions or need further clarification, feel free to ask.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"application/databricks.mlflow.trace": "\"tr-ea2226413395413ba2cf52cffc523502\"",
"text/plain": [
"Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# This is the same input your chain's REST API will accept.\n",
"question = {\n",
" \"messages\": [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"what is rag?\",\n",
" },\n",
" ]\n",
"}\n",
"\n",
"chain.invoke(question)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "5d68e37d-0980-4a02-bf8d-885c3853f6c1",
"showTitle": false,
"title": ""
}
},
"outputs": [],
"source": [
"mlflow.models.set_model(model=model)"
]
}
],
"metadata": {
"application/vnd.databricks.v1+notebook": {
"dashboards": [],
"environmentMetadata": null,
"language": "python",
"notebookMetadata": {
"pythonIndentUnit": 4
},
"notebookName": "Untitled Notebook 2024-10-16 19:35:16",
"widgets": {}
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}


@ -873,6 +873,11 @@ router_settings:
| ALLOWED_EMAIL_DOMAINS | List of email domains allowed for access
| ARIZE_API_KEY | API key for Arize platform integration
| ARIZE_SPACE_KEY | Space key for Arize platform
| ARGILLA_BATCH_SIZE | Batch size for Argilla logging
| ARGILLA_API_KEY | API key for Argilla platform
| ARGILLA_SAMPLING_RATE | Sampling rate for Argilla logging
| ARGILLA_DATASET_NAME | Dataset name for Argilla logging
| ARGILLA_BASE_URL | Base URL for Argilla service
| ATHINA_API_KEY | API key for Athina service
| AUTH_STRATEGY | Strategy used for authentication (e.g., OAuth, API key)
| AWS_ACCESS_KEY_ID | Access Key ID for AWS services


@ -54,6 +54,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"langtrace",
"gcs_bucket",
"opik",
"argilla",
]
_known_custom_logger_compatible_callbacks: List = list(
get_args(_custom_logger_compatible_callbacks_literal)
@ -61,6 +62,8 @@ _known_custom_logger_compatible_callbacks: List = list(
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
argilla_batch_size: Optional[int] = None
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Callable] = (
[]
) # internal variable - async custom callbacks are routed here.


@ -693,7 +693,11 @@ def completion_cost(
completion_response, RerankResponse
):
meta_obj = completion_response.meta
billed_units = meta_obj.get("billed_units", {}) or {}
if meta_obj is not None:
billed_units = meta_obj.get("billed_units", {}) or {}
else:
billed_units = {}
search_units = (
billed_units.get("search_units") or 1
) # cohere charges per request by default.


@ -0,0 +1,410 @@
"""
Send logs to Argilla for annotation
"""
import asyncio
import json
import os
import random
import time
import traceback
import types
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TypedDict, Union
import dotenv # type: ignore
import httpx
import requests # type: ignore
from pydantic import BaseModel # type: ignore
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None
# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None
class ArgillaItem(TypedDict):
fields: Dict[str, Any]
class ArgillaPayload(TypedDict):
items: List[ArgillaItem]
class ArgillaCredentialsObject(TypedDict):
ARGILLA_API_KEY: str
ARGILLA_DATASET_NAME: str
ARGILLA_BASE_URL: str
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
types.FunctionType,
types.GeneratorType,
BaseModel,
)
return not isinstance(value, non_serializable_types)
class ArgillaLogger(CustomBatchLogger):
def __init__(
self,
argilla_api_key: Optional[str] = None,
argilla_dataset_name: Optional[str] = None,
argilla_base_url: Optional[str] = None,
**kwargs,
):
if litellm.argilla_transformation_object is None:
raise Exception(
"'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
)
self.validate_argilla_transformation_object(
litellm.argilla_transformation_object
)
self.argilla_transformation_object = litellm.argilla_transformation_object
self.default_credentials = self.get_credentials_from_env(
argilla_api_key=argilla_api_key,
argilla_dataset_name=argilla_dataset_name,
argilla_base_url=argilla_base_url,
)
self.sampling_rate: float = (
float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore
if os.getenv("ARGILLA_SAMPLING_RATE") is not None
and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0
)
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
_batch_size = (
os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
)
if _batch_size:
self.batch_size = int(_batch_size)
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
def validate_argilla_transformation_object(
self, argilla_transformation_object: Dict[str, Any]
):
if not isinstance(argilla_transformation_object, dict):
raise Exception(
"'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
)
for v in argilla_transformation_object.values():
if v not in SUPPORTED_PAYLOAD_FIELDS:
raise Exception(
f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
)
def get_credentials_from_env(
self,
argilla_api_key: Optional[str],
argilla_dataset_name: Optional[str],
argilla_base_url: Optional[str],
) -> ArgillaCredentialsObject:
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
if _credentials_api_key is None:
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
_credentials_base_url = (
argilla_base_url
or os.getenv("ARGILLA_BASE_URL")
or "http://localhost:6900/"
)
if _credentials_base_url is None:
raise Exception(
"Invalid Argilla Base URL given. _credentials_base_url=None."
)
_credentials_dataset_name = (
argilla_dataset_name
or os.getenv("ARGILLA_DATASET_NAME")
or "litellm-completion"
)
if _credentials_dataset_name is None:
raise Exception("Invalid Argilla Dataset give. Value=None.")
else:
dataset_response = litellm.module_level_client.get(
url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
headers={"X-Argilla-Api-Key": _credentials_api_key},
)
json_response = dataset_response.json()
if (
"items" in json_response
and isinstance(json_response["items"], list)
and len(json_response["items"]) > 0
):
_credentials_dataset_name = json_response["items"][0]["id"]
return ArgillaCredentialsObject(
ARGILLA_API_KEY=_credentials_api_key,
ARGILLA_BASE_URL=_credentials_base_url,
ARGILLA_DATASET_NAME=_credentials_dataset_name,
)
def get_chat_messages(
self, payload: StandardLoggingPayload
) -> List[Dict[str, Any]]:
payload_messages = payload.get("messages", None)
if payload_messages is None:
raise Exception("No chat messages found in payload.")
if (
isinstance(payload_messages, list)
and len(payload_messages) > 0
and isinstance(payload_messages[0], dict)
):
return payload_messages
elif isinstance(payload_messages, dict):
return [payload_messages]
else:
raise Exception(f"Invalid chat messages format: {payload_messages}")
def get_str_response(self, payload: StandardLoggingPayload) -> str:
response = payload["response"]
if response is None:
raise Exception("No response found in payload.")
if isinstance(response, str):
return response
elif isinstance(response, dict):
return (
response.get("choices", [{}])[0].get("message", {}).get("content", "")
)
else:
raise Exception(f"Invalid response format: {response}")
def _prepare_log_data(
self, kwargs, response_obj, start_time, end_time
) -> ArgillaItem:
try:
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if payload is None:
raise Exception("Error logging request payload. Payload=none.")
argilla_message = self.get_chat_messages(payload)
argilla_response = self.get_str_response(payload)
argilla_item: ArgillaItem = {"fields": {}}
for k, v in self.argilla_transformation_object.items():
if v == "messages":
argilla_item["fields"][k] = argilla_message
elif v == "response":
argilla_item["fields"][k] = argilla_response
else:
argilla_item["fields"][k] = payload.get(v, None)
return argilla_item
except Exception:
raise
def _send_batch(self):
if not self.log_queue:
return
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
headers = {"X-Argilla-Api-Key": argilla_api_key}
try:
response = requests.post(
url=url,
json=self.log_queue,
headers=headers,
)
if response.status_code >= 300:
verbose_logger.error(
f"Argilla Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
f"Batch of {len(self.log_queue)} runs successfully created"
)
self.log_queue.clear()
except Exception:
verbose_logger.exception("Argilla Layer Error - Error sending batch.")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0
)
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.debug(
"Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
)
if len(self.log_queue) >= self.batch_size:
self._send_batch()
except Exception:
verbose_logger.exception("Langsmith Layer Error - log_success_event error")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = self.sampling_rate
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.debug(
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except Exception:
verbose_logger.exception(
"Argilla Layer Error - error logging async success event."
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
sampling_rate = self.sampling_rate
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.info("Langsmith Failure Event Logging!")
try:
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except Exception:
verbose_logger.exception(
"Langsmith Layer Error - error logging async failure event."
)
async def async_send_batch(self):
"""
Sends queued records to the Argilla records/bulk endpoint
Sends records from self.log_queue
Returns: None
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
if not self.log_queue:
return
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
headers = {"X-Argilla-Api-Key": argilla_api_key}
try:
response = await self.async_httpx_client.put(
url=url,
data=json.dumps(
{
"items": self.log_queue,
}
),
headers=headers,
timeout=60000,
)
response.raise_for_status()
if response.status_code >= 300:
verbose_logger.error(
f"Argilla Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
"Batch of %s runs successfully created", len(self.log_queue)
)
except httpx.HTTPStatusError:
verbose_logger.exception("Argilla HTTP Error")
except Exception:
verbose_logger.exception("Argilla Layer Error")


@ -59,6 +59,7 @@ from litellm.utils import (
)
from ..integrations.aispend import AISpendLogger
from ..integrations.argilla import ArgillaLogger
from ..integrations.athina import AthinaLogger
from ..integrations.berrispend import BerriSpendLogger
from ..integrations.braintrust_logging import BraintrustLogger
@ -2339,6 +2340,14 @@ def _init_custom_logger_compatible_class(
_langsmith_logger = LangsmithLogger()
_in_memory_loggers.append(_langsmith_logger)
return _langsmith_logger # type: ignore
elif logging_integration == "argilla":
for callback in _in_memory_loggers:
if isinstance(callback, ArgillaLogger):
return callback # type: ignore
_argilla_logger = ArgillaLogger()
_in_memory_loggers.append(_argilla_logger)
return _argilla_logger # type: ignore
elif logging_integration == "literalai":
for callback in _in_memory_loggers:
if isinstance(callback, LiteralAILogger):
@ -2521,6 +2530,10 @@ def get_custom_logger_compatible_class(
for callback in _in_memory_loggers:
if isinstance(callback, LangsmithLogger):
return callback
elif logging_integration == "argilla":
for callback in _in_memory_loggers:
if isinstance(callback, ArgillaLogger):
return callback
elif logging_integration == "literalai":
for callback in _in_memory_loggers:
if isinstance(callback, LiteralAILogger):


@ -163,10 +163,11 @@ class AsyncHTTPHandler:
try:
if timeout is None:
timeout = self.timeout
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
)
response = await self.client.send(req, stream=stream)
response = await self.client.send(req)
response.raise_for_status()
return response
except (httpx.RemoteProtocolError, httpx.ConnectError):


@ -3,11 +3,25 @@ Common utility functions used for translating messages across providers
"""
import json
from typing import Dict, List
from copy import deepcopy
from typing import Dict, List, Literal, Optional
from litellm.types.llms.openai import AllMessageValues
import litellm
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
content="Please continue.", role="user"
)
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
content="Please continue.", role="assistant"
)
def convert_content_list_to_str(message: AllMessageValues) -> str:
"""
@ -69,3 +83,145 @@ def get_content_from_model_response(response: ModelResponse) -> str:
elif isinstance(choice, StreamingChoices):
content += getattr(choice, "delta", {}).get("content", "") or ""
return content
def detect_first_expected_role(
messages: List[AllMessageValues],
) -> Optional[Literal["user", "assistant"]]:
"""
Detect the first expected role based on the message sequence.
Rules:
1. If messages list is empty, assume 'user' starts
2. If first message is from assistant, expect 'user' next
3. If first message is from user, expect 'assistant' next
4. If first message is system, look at the next non-system message
Returns:
str: Either 'user' or 'assistant'
None: If no 'user' or 'assistant' messages provided
"""
if not messages:
return "user"
for message in messages:
if message["role"] == "system":
continue
return "user" if message["role"] == "assistant" else "assistant"
return None
def _insert_user_continue_message(
messages: List[AllMessageValues],
user_continue_message: Optional[ChatCompletionUserMessage],
ensure_alternating_roles: bool,
) -> List[AllMessageValues]:
"""
Inserts a user continue message into the messages list.
Handles three cases:
1. Initial assistant message
2. Final assistant message
3. Consecutive assistant messages
Only inserts messages between consecutive assistant messages,
ignoring all other role types.
"""
if not messages:
return messages
result_messages = messages.copy() # Don't modify the input list
continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE
# Handle first message if it's an assistant message
if result_messages[0]["role"] == "assistant":
result_messages.insert(0, continue_message)
# Handle consecutive assistant messages and final message
i = 1 # Start from second message since we handled first message
while i < len(result_messages):
curr_message = result_messages[i]
prev_message = result_messages[i - 1]
# Only check for consecutive assistant messages
# Ignore all other role types
if curr_message["role"] == "assistant" and prev_message["role"] == "assistant":
result_messages.insert(i, continue_message)
i += 2 # Skip over the message we just inserted
else:
i += 1
# Handle final message
if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles:
result_messages.append(continue_message)
return result_messages
def _insert_assistant_continue_message(
messages: List[AllMessageValues],
assistant_continue_message: Optional[ChatCompletionAssistantMessage] = None,
ensure_alternating_roles: bool = True,
) -> List[AllMessageValues]:
"""
Add assistant continuation messages between consecutive user messages.
Args:
messages: List of message dictionaries
assistant_continue_message: Optional custom assistant message
ensure_alternating_roles: Whether to enforce alternating roles
Returns:
Modified list of messages with inserted assistant messages
"""
if not ensure_alternating_roles or len(messages) <= 1:
return messages
# Create a new list to store modified messages
modified_messages: List[AllMessageValues] = []
for i, message in enumerate(messages):
modified_messages.append(message)
# Check if we need to insert an assistant message
if (
i < len(messages) - 1 # Not the last message
and message.get("role") == "user" # Current is user
and messages[i + 1].get("role") == "user"
): # Next is user
# Insert assistant message
continue_message = (
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
)
modified_messages.append(continue_message)
return modified_messages
def get_completion_messages(
messages: List[AllMessageValues],
assistant_continue_message: Optional[ChatCompletionAssistantMessage],
user_continue_message: Optional[ChatCompletionUserMessage],
ensure_alternating_roles: bool,
) -> List[AllMessageValues]:
"""
Ensures messages alternate between user and assistant roles by adding placeholders
only when there are consecutive messages of the same role.
1. ensure 'user' message before 1st 'assistant' message
2. ensure 'user' message after last 'assistant' message
"""
if not ensure_alternating_roles:
return messages.copy()
## INSERT USER CONTINUE MESSAGE
messages = _insert_user_continue_message(
messages, user_continue_message, ensure_alternating_roles
)
## INSERT ASSISTANT CONTINUE MESSAGE
messages = _insert_assistant_continue_message(
messages, assistant_continue_message, ensure_alternating_roles
)
return messages


@ -1388,9 +1388,14 @@ def anthropic_messages_pt(
for m in user_message_types_block["content"]:
if m.get("type", "") == "image_url":
m = cast(ChatCompletionImageObject, m)
image_chunk = convert_to_anthropic_image_obj(
openai_image_url=m["image_url"]["url"] # type: ignore
)
if isinstance(m["image_url"], str):
image_chunk = convert_to_anthropic_image_obj(
openai_image_url=m["image_url"]
)
else:
image_chunk = convert_to_anthropic_image_obj(
openai_image_url=m["image_url"]["url"]
)
_anthropic_content_element = AnthropicMessagesImageParam(
type="image",


@ -887,10 +887,16 @@ class VertexLLM(VertexBase):
if "citationMetadata" in candidate:
citation_metadata.append(candidate["citationMetadata"])
if "text" in candidate["content"]["parts"][0]:
if (
"parts" in candidate["content"]
and "text" in candidate["content"]["parts"][0]
):
content_str = candidate["content"]["parts"][0]["text"]
if "functionCall" in candidate["content"]["parts"][0]:
if (
"parts" in candidate["content"]
and "functionCall" in candidate["content"]["parts"][0]
):
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=candidate["content"]["parts"][0]["functionCall"][
"name"
@ -1358,7 +1364,11 @@ class ModelResponseIterator:
if _candidates and len(_candidates) > 0:
gemini_chunk = _candidates[0]
if gemini_chunk and "content" in gemini_chunk:
if (
gemini_chunk
and "content" in gemini_chunk
and "parts" in gemini_chunk["content"]
):
if "text" in gemini_chunk["content"]["parts"][0]:
text = gemini_chunk["content"]["parts"][0]["text"]
elif "functionCall" in gemini_chunk["content"]["parts"][0]:


@ -109,6 +109,7 @@ from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.common_utils import get_completion_messages
from .llms.prompt_templates.factory import (
custom_prompt,
function_call_prompt,
@ -144,7 +145,11 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
VertexEmbedding,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.llms.openai import (
ChatCompletionAssistantMessage,
ChatCompletionUserMessage,
HttpxBinaryResponseContent,
)
from .types.utils import (
AdapterCompletionStreamWrapper,
ChatCompletionMessageToolCall,
@ -748,6 +753,15 @@ def completion( # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
fallbacks = kwargs.get("fallbacks", None)
headers = kwargs.get("headers", None) or extra_headers
ensure_alternating_roles: Optional[bool] = kwargs.get(
"ensure_alternating_roles", None
)
user_continue_message: Optional[ChatCompletionUserMessage] = kwargs.get(
"user_continue_message", None
)
assistant_continue_message: Optional[ChatCompletionAssistantMessage] = kwargs.get(
"assistant_continue_message", None
)
if headers is None:
headers = {}
@ -784,7 +798,12 @@ def completion( # type: ignore
### Admin Controls ###
no_log = kwargs.get("no-log", False)
### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
messages = deepcopy(messages)
messages = get_completion_messages(
messages=messages,
ensure_alternating_roles=ensure_alternating_roles or False,
user_continue_message=user_continue_message,
assistant_continue_message=assistant_continue_message,
)
######## end of unpacking kwargs ###########
openai_params = [
"functions",


@ -5,8 +5,8 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
assistant_settings:
custom_llm_provider: azure
litellm_params:
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
litellm_settings:
callbacks: ["argilla"]
argilla_transformation_object:
user_input: "messages"
llm_output: "response"


@ -1,13 +1,17 @@
import re
import sys
import traceback
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.types.router import (
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
ConfigurableClientsideParamsCustomAuth,
)
def _get_request_ip_address(
@ -73,8 +77,41 @@ def check_complete_credentials(request_body: dict) -> bool:
return False
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
"""
Check if request_body_value matches regex_str or is exactly equal to it
"""
if re.match(regex_str, request_body_value) or regex_str == request_body_value:
return True
return False
def _is_param_allowed(
param: str,
request_body_value: Any,
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
"""
Check if param is allowed: allowed items are either a str (an exact param name) or a dict whose value is a regex/exact string that request_body_value must match
"""
if configurable_clientside_auth_params is None:
return False
for item in configurable_clientside_auth_params:
if isinstance(item, str) and param == item:
return True
elif isinstance(item, Dict):
if param == "api_base" and check_regex_or_str_match(
request_body_value=request_body_value,
regex_str=item["api_base"],
): # assume param is a regex
return True
return False
def _allow_model_level_clientside_configurable_parameters(
model: str, param: str, llm_router: Optional[Router]
model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
"""
Check if model is allowed to use configurable client-side params
@ -99,10 +136,11 @@ def _allow_model_level_clientside_configurable_parameters(
if model_info is None or model_info.configurable_clientside_auth_params is None:
return False
if param in model_info.configurable_clientside_auth_params:
return True
return False
return _is_param_allowed(
param=param,
request_body_value=request_body_value,
configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
)
def is_request_body_safe(
@ -127,7 +165,10 @@ def is_request_body_safe(
return True
elif (
_allow_model_level_clientside_configurable_parameters(
model=model, param=param, llm_router=llm_router
model=model,
param=param,
request_body_value=request_body[param],
llm_router=llm_router,
)
is True
):


@ -0,0 +1,24 @@
from fastapi import HTTPException
from litellm.proxy._types import LitellmUserRoles
def check_is_admin_only_access(ui_access_mode: str) -> bool:
"""Checks ui access mode is admin_only"""
return ui_access_mode == "admin_only"
def has_admin_ui_access(user_role: str) -> bool:
"""
Check if the user has admin access to the UI.
Returns:
bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise.
"""
if (
user_role != LitellmUserRoles.PROXY_ADMIN.value
and user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
return False
return True


@ -31,6 +31,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
show_missing_vars_in_env,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.sso_helper_utils import (
check_is_admin_only_access,
has_admin_ui_access,
)
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
@ -545,17 +549,16 @@ async def auth_callback(request: Request):
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
)
## CHECK IF ROLE ALLOWED TO USE PROXY ##
if ui_access_mode == "admin_only" and (
user_role != LitellmUserRoles.PROXY_ADMIN.value
or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
verbose_proxy_logger.debug("EXCEPTION RAISED")
raise HTTPException(
status_code=401,
detail={
"error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
},
)
is_admin_only_access = check_is_admin_only_access(ui_access_mode)
if is_admin_only_access:
has_access = has_admin_ui_access(user_role)
if not has_access:
raise HTTPException(
status_code=401,
detail={
"error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
},
)
import jwt


@ -1124,6 +1124,7 @@ class PrismaClient:
"MonthlyGlobalSpendPerKey",
"MonthlyGlobalSpendPerUserPerKey",
"Last30dTopEndUsersSpend",
"DailyTagSpend",
]
required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)


@ -104,6 +104,7 @@ from litellm.types.llms.openai import (
Thread,
)
from litellm.types.router import (
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
SPECIAL_MODEL_INFO_PARAMS,
VALID_LITELLM_ENVIRONMENTS,
AlertingConfig,
@ -4183,7 +4184,7 @@ class Router:
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
configurable_clientside_auth_params: Optional[List[str]] = None
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
for model in self.model_list:
is_match = False


@ -65,7 +65,7 @@ class HttpxPartType(TypedDict, total=False):
class HttpxContentType(TypedDict, total=False):
role: Literal["user", "model"]
parts: Required[List[HttpxPartType]]
parts: List[HttpxPartType]
class ContentType(TypedDict, total=False):


@ -5,10 +5,11 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
import datetime
import enum
import uuid
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import httpx
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict
from ..exceptions import RateLimitError
from .completion import CompletionRequest
@ -16,6 +17,15 @@ from .embedding import EmbeddingRequest
from .utils import ModelResponse
class ConfigurableClientsideParamsCustomAuth(TypedDict):
api_base: str
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = Optional[
List[Union[str, ConfigurableClientsideParamsCustomAuth]]
]
class ModelConfig(BaseModel):
model_name: str
litellm_params: Union[CompletionRequest, EmbeddingRequest]
@ -139,7 +149,7 @@ class GenericLiteLLMParams(BaseModel):
)
max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs
configurable_clientside_auth_params: Optional[List[str]] = None
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None
## VERTEX AI ##
@ -311,9 +321,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
stream_timeout: Optional[Union[float, str]]
max_retries: Optional[int]
organization: Optional[Union[List, str]] # for openai orgs
configurable_clientside_auth_params: Optional[
List[str]
] # for allowing api base switching on finetuned models
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS # for allowing api base switching on finetuned models
## DROP PARAMS ##
drop_params: Optional[bool]
## UNIFIED PROJECT/REGION ##
@ -496,7 +504,7 @@ class ModelGroupInfo(BaseModel):
supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False)
supported_openai_params: Optional[List[str]] = Field(default=[])
configurable_clientside_auth_params: Optional[List[str]] = None
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
class AssistantsTypedDict(TypedDict):


@ -1246,6 +1246,9 @@ all_litellm_params = [
"user_continue_message",
"configurable_clientside_auth_params",
"weight",
"ensure_alternating_roles",
"assistant_continue_message",
"user_continue_message",
]


@ -0,0 +1,70 @@
# What is this?
## Tests if proxy/auth/auth_utils.py works as expected
import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.auth.auth_utils import (
_allow_model_level_clientside_configurable_parameters,
)
from litellm.router import Router
@pytest.mark.parametrize(
"allowed_param, input_value, should_return_true",
[
("api_base", {"api_base": "http://dummy.com"}, True),
(
{"api_base": "https://api.openai.com/v1"},
{"api_base": "https://api.openai.com/v1"},
True,
), # should return True
(
{"api_base": "https://api.openai.com/v1"},
{"api_base": "https://api.anthropic.com/v1"},
False,
), # should return False
(
{"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"},
{"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
True,
),
(
{"api_base": "^https://litellm.*novice\.fireworks\.ai/v1$"},
{"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
False,
),
],
)
def test_configurable_clientside_parameters(
allowed_param, input_value, should_return_true
):
router = Router(
model_list=[
{
"model_name": "dummy-model",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "dummy-key",
"configurable_clientside_auth_params": [allowed_param],
},
}
]
)
resp = _allow_model_level_clientside_configurable_parameters(
model="dummy-model",
param="api_base",
request_body_value=input_value["api_base"],
llm_router=router,
)
print(resp)
assert resp == should_return_true


@ -22,9 +22,13 @@ from litellm.llms.prompt_templates.factory import (
llama_2_chat_pt,
prompt_factory,
)
from litellm.llms.prompt_templates.common_utils import (
get_completion_messages,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
_gemini_convert_messages_with_history,
)
from unittest.mock import AsyncMock, MagicMock, patch
def test_llama_3_prompt():
@ -457,3 +461,217 @@ def test_azure_tool_call_invoke_helper():
"function_call": {"name": "get_weather", "arguments": ""},
},
]
@pytest.mark.parametrize(
"messages, expected_messages, user_continue_message, assistant_continue_message",
[
(
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{"role": "user", "content": "What is Databricks?"},
{"role": "user", "content": "What is Azure?"},
{"role": "assistant", "content": "I don't know anyything, do you?"},
],
[
{"role": "user", "content": "Hello!"},
{
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
{"role": "user", "content": "What is Databricks?"},
{
"role": "assistant",
"content": "Please continue.",
},
{"role": "user", "content": "What is Azure?"},
{
"role": "assistant",
"content": "I don't know anyything, do you?",
},
{
"role": "user",
"content": "Please continue.",
},
],
None,
None,
),
(
[
{"role": "user", "content": "Hello!"},
],
[
{"role": "user", "content": "Hello!"},
],
None,
None,
),
(
[
{"role": "user", "content": "Hello!"},
{"role": "user", "content": "What is Databricks?"},
],
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Please continue."},
{"role": "user", "content": "What is Databricks?"},
],
None,
None,
),
(
[
{"role": "user", "content": "Hello!"},
{"role": "user", "content": "What is Databricks?"},
{"role": "user", "content": "What is Azure?"},
],
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Please continue."},
{"role": "user", "content": "What is Databricks?"},
{
"role": "assistant",
"content": "Please continue.",
},
{"role": "user", "content": "What is Azure?"},
],
None,
None,
),
(
[
{"role": "user", "content": "Hello!"},
{
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
{"role": "user", "content": "What is Databricks?"},
{"role": "user", "content": "What is Azure?"},
{"role": "assistant", "content": "I don't know anyything, do you?"},
{"role": "assistant", "content": "I can't repeat sentences."},
],
[
{"role": "user", "content": "Hello!"},
{
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
{"role": "user", "content": "What is Databricks?"},
{
"role": "assistant",
"content": "Please continue",
},
{"role": "user", "content": "What is Azure?"},
{
"role": "assistant",
"content": "I don't know anyything, do you?",
},
{
"role": "user",
"content": "Ok",
},
{
"role": "assistant",
"content": "I can't repeat sentences.",
},
{"role": "user", "content": "Ok"},
],
{
"role": "user",
"content": "Ok",
},
{
"role": "assistant",
"content": "Please continue",
},
),
],
)
def test_ensure_alternating_roles(
messages, expected_messages, user_continue_message, assistant_continue_message
):
messages = get_completion_messages(
messages=messages,
assistant_continue_message=assistant_continue_message,
user_continue_message=user_continue_message,
ensure_alternating_roles=True,
)
print(messages)
assert messages == expected_messages
def test_alternating_roles_e2e():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import json
litellm.set_verbose = True
http_handler = HTTPHandler()
with patch.object(http_handler, "post", new=MagicMock()) as mock_post:
response = litellm.completion(
**{
"model": "databricks/databricks-meta-llama-3-1-70b-instruct",
"messages": [
{"role": "user", "content": "Hello!"},
{
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
{"role": "user", "content": "What is Databricks?"},
{"role": "user", "content": "What is Azure?"},
{"role": "assistant", "content": "I don't know anyything, do you?"},
{"role": "assistant", "content": "I can't repeat sentences."},
],
"user_continue_message": {
"role": "user",
"content": "Ok",
},
"assistant_continue_message": {
"role": "assistant",
"content": "Please continue",
},
"ensure_alternating_roles": True,
},
client=http_handler,
)
print(f"response: {response}")
assert mock_post.call_args.kwargs["data"] == json.dumps(
{
"model": "databricks-meta-llama-3-1-70b-instruct",
"messages": [
{"role": "user", "content": "Hello!"},
{
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
{"role": "user", "content": "What is Databricks?"},
{
"role": "assistant",
"content": "Please continue",
},
{"role": "user", "content": "What is Azure?"},
{
"role": "assistant",
"content": "I don't know anyything, do you?",
},
{
"role": "user",
"content": "Ok",
},
{
"role": "assistant",
"content": "I can't repeat sentences.",
},
{
"role": "user",
"content": "Ok",
},
],
"stream": False,
}
)


@ -0,0 +1,38 @@
# What is this?
## This tests the SSO helper utils used by the proxy server
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime
from dotenv import load_dotenv
from fastapi import Request
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import logging
from litellm.proxy.management_endpoints.sso_helper_utils import (
check_is_admin_only_access,
has_admin_ui_access,
)
from litellm.proxy._types import LitellmUserRoles
def test_check_is_admin_only_access():
assert check_is_admin_only_access("admin_only") is True
assert check_is_admin_only_access("user_only") is False
def test_has_admin_ui_access():
assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN.value) is True
assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value) is True
assert has_admin_ui_access(LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value) is False