From f252350881534e200f974778ef2197f289eb9ff3 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 17 Oct 2024 22:09:11 -0700 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (10/17/2024) (#6293) * fix(ui_sso.py): fix faulty admin only check Fixes https://github.com/BerriAI/litellm/issues/6286 * refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing Prevent future regressions * feat(prompt_factory): support 'ensure_alternating_roles' param Closes https://github.com/BerriAI/litellm/issues/6257 * fix(proxy/utils.py): add dailytagspend to expected views * feat(auth_utils.py): support setting regex for clientside auth credentials Fixes https://github.com/BerriAI/litellm/issues/6203 * build(cookbook): add tutorial for mlflow + langchain + litellm proxy tracing * feat(argilla.py): add argilla logging integration Closes https://github.com/BerriAI/litellm/issues/6201 * fix: fix linting errors * fix: fix ruff error * test: fix test * fix: update vertex ai assumption - parts not always guaranteed (#6296) * docs(configs.md): add argila env var to docs --- ...flow_langchain_tracing_litellm_proxy.ipynb | 312 +++++++++++++ docs/my-website/docs/proxy/configs.md | 5 + litellm/__init__.py | 3 + litellm/cost_calculator.py | 6 +- litellm/integrations/argilla.py | 410 ++++++++++++++++++ litellm/litellm_core_utils/litellm_logging.py | 13 + litellm/llms/custom_httpx/http_handler.py | 3 +- litellm/llms/prompt_templates/common_utils.py | 160 ++++++- litellm/llms/prompt_templates/factory.py | 11 +- .../vertex_and_google_ai_studio_gemini.py | 16 +- litellm/main.py | 23 +- litellm/proxy/_new_secret_config.yaml | 10 +- litellm/proxy/auth/auth_utils.py | 55 ++- .../management_endpoints/sso_helper_utils.py | 24 + litellm/proxy/management_endpoints/ui_sso.py | 25 +- litellm/proxy/utils.py | 1 + litellm/router.py | 3 +- litellm/types/llms/vertex_ai.py | 2 +- litellm/types/router.py | 20 +- litellm/types/utils.py | 3 + tests/local_testing/test_auth_utils.py | 70 +++ tests/local_testing/test_prompt_factory.py | 218 ++++++++++ .../local_testing/test_ui_sso_helper_utils.py | 38 ++ 23 files changed, 1388 insertions(+), 43 deletions(-) create mode 100644 cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb create mode 100644 litellm/integrations/argilla.py create mode 100644 litellm/proxy/management_endpoints/sso_helper_utils.py create mode 100644 tests/local_testing/test_auth_utils.py create mode 100644 tests/local_testing/test_ui_sso_helper_utils.py diff --git a/cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb b/cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb new file mode 100644 index 000000000..0c684942f --- /dev/null +++ b/cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb @@ -0,0 +1,312 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Databricks Notebook with MLFlow AutoLogging for LiteLLM Proxy calls\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5e2812ed-8000-4793-b090-49a31464d810", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6 " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + 
"inputWidgets": {}, + "nuid": "52530b37-1860-4bba-a6c1-723de83bc58f", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install \"langchain-openai<=0.3.1\"" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "43c6f4b1-e2d5-431c-b1a2-b97df7707d59", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "# Before logging this chain using the driver notebook, you must comment out this line.\n", + "dbutils.library.restartPython() " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "88eb8dd7-16b1-480b-aa70-cd429ef87159", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "import mlflow\n", + "from operator import itemgetter\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.prompts import PromptTemplate\n", + "from langchain_core.runnables import RunnableLambda\n", + "from langchain_databricks import ChatDatabricks\n", + "from langchain_openai import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "f0fdca8f-6f6f-407c-ad4a-0d5a2778728e", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "import mlflow\n", + "mlflow.langchain.autolog()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "2ef67315-e468-4d60-a318-98c2cac75bc4", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "# These helper functions parse the `messages` array.\n", + "\n", + "# Return the string contents of the most recent message from the user\n", + "def extract_user_query_string(chat_messages_array):\n", + " return chat_messages_array[-1][\"content\"]\n", + "\n", + "\n", + "# Return the chat history, which is is everything before the last question\n", + "def extract_chat_history(chat_messages_array):\n", + " return chat_messages_array[:-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "17708467-1976-48bd-94a0-8c7895cfae3b", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "model = ChatOpenAI(\n", + " openai_api_base=\"LITELLM_PROXY_BASE_URL\", # e.g.: http://0.0.0.0:4000\n", + " model = \"gpt-3.5-turbo\", # LITELLM 'model_name'\n", + " temperature=0.1, \n", + " api_key=\"LITELLM_PROXY_API_KEY\" # e.g.: \"sk-1234\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a5f2c2af-82f7-470d-b559-47b67fb00cda", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "############\n", + "# Prompt Template for generation\n", + "############\n", + "prompt = PromptTemplate(\n", + " 
template=\"You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}\",\n", + " input_variables=[\"question\"],\n", + ")\n", + "\n", + "############\n", + "# FM for generation\n", + "# ChatDatabricks accepts any /llm/v1/chat model serving endpoint\n", + "############\n", + "model = ChatDatabricks(\n", + " endpoint=\"databricks-dbrx-instruct\",\n", + " extra_params={\"temperature\": 0.01, \"max_tokens\": 500},\n", + ")\n", + "\n", + "\n", + "############\n", + "# Simple chain\n", + "############\n", + "# The framework requires the chain to return a string value.\n", + "chain = (\n", + " {\n", + " \"question\": itemgetter(\"messages\")\n", + " | RunnableLambda(extract_user_query_string),\n", + " \"chat_history\": itemgetter(\"messages\") | RunnableLambda(extract_chat_history),\n", + " }\n", + " | prompt\n", + " | model\n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "366edd90-62a1-4d6f-8a65-0211fb24ca02", + "showTitle": false, + "title": "" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Hello there! I\\'m here to help with your questions. Regarding your query about \"rag,\" it\\'s not something typically associated with a \"hello world\" bot, but I\\'m happy to explain!\\n\\nRAG, or Remote Angular GUI, is a tool that allows you to create and manage Angular applications remotely. It\\'s a way to develop and test Angular components and applications without needing to set up a local development environment. This can be particularly useful for teams working on distributed systems or for developers who prefer to work in a cloud-based environment.\\n\\nI hope this explanation of RAG has been helpful and interesting! 
If you have any other questions or need further clarification, feel free to ask.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "application/databricks.mlflow.trace": "\"tr-ea2226413395413ba2cf52cffc523502\"", + "text/plain": [ + "Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# This is the same input your chain's REST API will accept.\n", + "question = {\n", + " \"messages\": [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"what is rag?\",\n", + " },\n", + " ]\n", + "}\n", + "\n", + "chain.invoke(question)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5d68e37d-0980-4a02-bf8d-885c3853f6c1", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "mlflow.models.set_model(model=model)" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "dashboards": [], + "environmentMetadata": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "Untitled Notebook 2024-10-16 19:35:16", + "widgets": {} + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 11ba0427b..67ad60c28 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -873,6 +873,11 @@ router_settings: | ALLOWED_EMAIL_DOMAINS | List of email domains allowed for access | ARIZE_API_KEY | API key for Arize platform integration | ARIZE_SPACE_KEY | Space key for Arize platform +| ARGILLA_BATCH_SIZE | Batch size for Argilla logging +| ARGILLA_API_KEY | API key for Argilla platform +| ARGILLA_SAMPLING_RATE | Sampling rate for Argilla logging +| ARGILLA_DATASET_NAME | Dataset name for Argilla logging +| ARGILLA_BASE_URL | Base URL for Argilla service | ATHINA_API_KEY | API key for Athina service | AUTH_STRATEGY | Strategy used for authentication (e.g., OAuth, API key) | AWS_ACCESS_KEY_ID | Access Key ID for AWS services diff --git a/litellm/__init__.py b/litellm/__init__.py index ac44c40d2..701d7b23b 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -54,6 +54,7 @@ _custom_logger_compatible_callbacks_literal = Literal[ "langtrace", "gcs_bucket", "opik", + "argilla", ] _known_custom_logger_compatible_callbacks: List = list( get_args(_custom_logger_compatible_callbacks_literal) @@ -61,6 +62,8 @@ _known_custom_logger_compatible_callbacks: List = list( callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] langfuse_default_tags: Optional[List[str]] = None langsmith_batch_size: Optional[int] = None +argilla_batch_size: Optional[int] = None +argilla_transformation_object: Optional[Dict[str, Any]] = None _async_input_callback: List[Callable] = ( [] ) # internal variable - async custom callbacks are routed here. 
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index b893e6646..d86706f5b 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -693,7 +693,11 @@ def completion_cost( completion_response, RerankResponse ): meta_obj = completion_response.meta - billed_units = meta_obj.get("billed_units", {}) or {} + if meta_obj is not None: + billed_units = meta_obj.get("billed_units", {}) or {} + else: + billed_units = {} + search_units = ( billed_units.get("search_units") or 1 ) # cohere charges per request by default. diff --git a/litellm/integrations/argilla.py b/litellm/integrations/argilla.py new file mode 100644 index 000000000..5c0bd4b1e --- /dev/null +++ b/litellm/integrations/argilla.py @@ -0,0 +1,410 @@ +""" +Send logs to Argilla for annotation +""" + +import asyncio +import json +import os +import random +import time +import traceback +import types +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, TypedDict, Union + +import dotenv # type: ignore +import httpx +import requests # type: ignore +from pydantic import BaseModel # type: ignore + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.llms.prompt_templates.common_utils import get_content_from_model_response +from litellm.types.utils import StandardLoggingPayload + + +class LangsmithInputs(BaseModel): + model: Optional[str] = None + messages: Optional[List[Any]] = None + stream: Optional[bool] = None + call_type: Optional[str] = None + litellm_call_id: Optional[str] = None + completion_start_time: Optional[datetime] = None + temperature: Optional[float] = None + max_tokens: Optional[int] = None + custom_llm_provider: Optional[str] = None + input: Optional[List[Any]] = None + log_event_type: Optional[str] = None + original_response: Optional[Any] = None + response_cost: Optional[float] = None + + # LiteLLM Virtual Key specific fields + user_api_key: Optional[str] = None + user_api_key_user_id: Optional[str] = None + user_api_key_team_alias: Optional[str] = None + + +class ArgillaItem(TypedDict): + fields: Dict[str, Any] + + +class ArgillaPayload(TypedDict): + items: List[ArgillaItem] + + +class ArgillaCredentialsObject(TypedDict): + ARGILLA_API_KEY: str + ARGILLA_DATASET_NAME: str + ARGILLA_BASE_URL: str + + +SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"] + + +def is_serializable(value): + non_serializable_types = ( + types.CoroutineType, + types.FunctionType, + types.GeneratorType, + BaseModel, + ) + return not isinstance(value, non_serializable_types) + + +class ArgillaLogger(CustomBatchLogger): + def __init__( + self, + argilla_api_key: Optional[str] = None, + argilla_dataset_name: Optional[str] = None, + argilla_base_url: Optional[str] = None, + **kwargs, + ): + if litellm.argilla_transformation_object is None: + raise Exception( + "'litellm.argilla_transformation_object' is required, to log your payload to Argilla." 
+ ) + self.validate_argilla_transformation_object( + litellm.argilla_transformation_object + ) + self.argilla_transformation_object = litellm.argilla_transformation_object + self.default_credentials = self.get_credentials_from_env( + argilla_api_key=argilla_api_key, + argilla_dataset_name=argilla_dataset_name, + argilla_base_url=argilla_base_url, + ) + self.sampling_rate: float = ( + float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore + if os.getenv("ARGILLA_SAMPLING_RATE") is not None + and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + _batch_size = ( + os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size + ) + if _batch_size: + self.batch_size = int(_batch_size) + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + super().__init__(**kwargs, flush_lock=self.flush_lock) + + def validate_argilla_transformation_object( + self, argilla_transformation_object: Dict[str, Any] + ): + if not isinstance(argilla_transformation_object, dict): + raise Exception( + "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla." + ) + + for v in argilla_transformation_object.values(): + if v not in SUPPORTED_PAYLOAD_FIELDS: + raise Exception( + f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key." + ) + + def get_credentials_from_env( + self, + argilla_api_key: Optional[str], + argilla_dataset_name: Optional[str], + argilla_base_url: Optional[str], + ) -> ArgillaCredentialsObject: + + _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY") + if _credentials_api_key is None: + raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.") + + _credentials_base_url = ( + argilla_base_url + or os.getenv("ARGILLA_BASE_URL") + or "http://localhost:6900/" + ) + if _credentials_base_url is None: + raise Exception( + "Invalid Argilla Base URL given. _credentials_base_url=None." + ) + + _credentials_dataset_name = ( + argilla_dataset_name + or os.getenv("ARGILLA_DATASET_NAME") + or "litellm-completion" + ) + if _credentials_dataset_name is None: + raise Exception("Invalid Argilla Dataset give. 
Value=None.") + else: + dataset_response = litellm.module_level_client.get( + url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}", + headers={"X-Argilla-Api-Key": _credentials_api_key}, + ) + json_response = dataset_response.json() + if ( + "items" in json_response + and isinstance(json_response["items"], list) + and len(json_response["items"]) > 0 + ): + _credentials_dataset_name = json_response["items"][0]["id"] + + return ArgillaCredentialsObject( + ARGILLA_API_KEY=_credentials_api_key, + ARGILLA_BASE_URL=_credentials_base_url, + ARGILLA_DATASET_NAME=_credentials_dataset_name, + ) + + def get_chat_messages( + self, payload: StandardLoggingPayload + ) -> List[Dict[str, Any]]: + payload_messages = payload.get("messages", None) + + if payload_messages is None: + raise Exception("No chat messages found in payload.") + + if ( + isinstance(payload_messages, list) + and len(payload_messages) > 0 + and isinstance(payload_messages[0], dict) + ): + return payload_messages + elif isinstance(payload_messages, dict): + return [payload_messages] + else: + raise Exception(f"Invalid chat messages format: {payload_messages}") + + def get_str_response(self, payload: StandardLoggingPayload) -> str: + response = payload["response"] + + if response is None: + raise Exception("No response found in payload.") + + if isinstance(response, str): + return response + elif isinstance(response, dict): + return ( + response.get("choices", [{}])[0].get("message", {}).get("content", "") + ) + else: + raise Exception(f"Invalid response format: {response}") + + def _prepare_log_data( + self, kwargs, response_obj, start_time, end_time + ) -> ArgillaItem: + try: + # Ensure everything in the payload is converted to str + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if payload is None: + raise Exception("Error logging request payload. 
Payload=none.") + + argilla_message = self.get_chat_messages(payload) + argilla_response = self.get_str_response(payload) + argilla_item: ArgillaItem = {"fields": {}} + for k, v in self.argilla_transformation_object.items(): + if v == "messages": + argilla_item["fields"][k] = argilla_message + elif v == "response": + argilla_item["fields"][k] = argilla_response + else: + argilla_item["fields"][k] = payload.get(v, None) + return argilla_item + except Exception: + raise + + def _send_batch(self): + if not self.log_queue: + return + + argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] + argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] + + url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" + + argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] + + headers = {"X-Argilla-Api-Key": argilla_api_key} + + try: + response = requests.post( + url=url, + json=self.log_queue, + headers=headers, + ) + + if response.status_code >= 300: + verbose_logger.error( + f"Argilla Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + f"Batch of {len(self.log_queue)} runs successfully created" + ) + + self.log_queue.clear() + except Exception: + verbose_logger.exception("Argilla Layer Error - Error sending batch.") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = ( + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore + if os.getenv("LANGSMITH_SAMPLING_RATE") is not None + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." + ) + + if len(self.log_queue) >= self.batch_size: + self._send_batch() + + except Exception: + verbose_logger.exception("Langsmith Layer Error - log_success_event error") + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.debug( + "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Argilla Layer Error - error logging async success event." + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. 
Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.info("Langsmith Failure Event Logging!") + try: + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Langsmith Layer Error - error logging async failure event." + ) + + async def async_send_batch(self): + """ + sends runs to /batch endpoint + + Sends runs from self.log_queue + + Returns: None + + Raises: Does not raise an exception, will only verbose_logger.exception() + """ + if not self.log_queue: + return + + argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] + argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] + + url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" + + argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] + + headers = {"X-Argilla-Api-Key": argilla_api_key} + + try: + response = await self.async_httpx_client.put( + url=url, + data=json.dumps( + { + "items": self.log_queue, + } + ), + headers=headers, + timeout=60000, + ) + response.raise_for_status() + + if response.status_code >= 300: + verbose_logger.error( + f"Argilla Error: {response.status_code} - {response.text}" + ) + else: + verbose_logger.debug( + "Batch of %s runs successfully created", len(self.log_queue) + ) + except httpx.HTTPStatusError: + verbose_logger.exception("Argilla HTTP Error") + except Exception: + verbose_logger.exception("Argilla Layer Error") diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index c23f9a979..f6a7b2a37 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -59,6 +59,7 @@ from litellm.utils import ( ) from ..integrations.aispend import AISpendLogger +from ..integrations.argilla import ArgillaLogger from ..integrations.athina import AthinaLogger from ..integrations.berrispend import BerriSpendLogger from ..integrations.braintrust_logging import BraintrustLogger @@ -2339,6 +2340,14 @@ def _init_custom_logger_compatible_class( _langsmith_logger = LangsmithLogger() _in_memory_loggers.append(_langsmith_logger) return _langsmith_logger # type: ignore + elif logging_integration == "argilla": + for callback in _in_memory_loggers: + if isinstance(callback, ArgillaLogger): + return callback # type: ignore + + _argilla_logger = ArgillaLogger() + _in_memory_loggers.append(_argilla_logger) + return _argilla_logger # type: ignore elif logging_integration == "literalai": for callback in _in_memory_loggers: if isinstance(callback, LiteralAILogger): @@ -2521,6 +2530,10 @@ def get_custom_logger_compatible_class( for callback in _in_memory_loggers: if isinstance(callback, LangsmithLogger): return callback + elif logging_integration == "argilla": + for callback in _in_memory_loggers: + if isinstance(callback, ArgillaLogger): + return callback elif logging_integration == "literalai": for callback in _in_memory_loggers: if isinstance(callback, LiteralAILogger): diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index f3f38e64d..a2b592ef8 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -163,10 
+163,11 @@ class AsyncHTTPHandler: try: if timeout is None: timeout = self.timeout + req = self.client.build_request( "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore ) - response = await self.client.send(req, stream=stream) + response = await self.client.send(req) response.raise_for_status() return response except (httpx.RemoteProtocolError, httpx.ConnectError): diff --git a/litellm/llms/prompt_templates/common_utils.py b/litellm/llms/prompt_templates/common_utils.py index 213c93764..d7bbf1c62 100644 --- a/litellm/llms/prompt_templates/common_utils.py +++ b/litellm/llms/prompt_templates/common_utils.py @@ -3,11 +3,25 @@ Common utility functions used for translating messages across providers """ import json -from typing import Dict, List +from copy import deepcopy +from typing import Dict, List, Literal, Optional -from litellm.types.llms.openai import AllMessageValues +import litellm +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionAssistantMessage, + ChatCompletionUserMessage, +) from litellm.types.utils import Choices, ModelResponse, StreamingChoices +DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage( + content="Please continue.", role="user" +) + +DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage( + content="Please continue.", role="assistant" +) + def convert_content_list_to_str(message: AllMessageValues) -> str: """ @@ -69,3 +83,145 @@ def get_content_from_model_response(response: ModelResponse) -> str: elif isinstance(choice, StreamingChoices): content += getattr(choice, "delta", {}).get("content", "") or "" return content + + +def detect_first_expected_role( + messages: List[AllMessageValues], +) -> Optional[Literal["user", "assistant"]]: + """ + Detect the first expected role based on the message sequence. + + Rules: + 1. If messages list is empty, assume 'user' starts + 2. If first message is from assistant, expect 'user' next + 3. If first message is from user, expect 'assistant' next + 4. If first message is system, look at the next non-system message + + Returns: + str: Either 'user' or 'assistant' + None: If no 'user' or 'assistant' messages provided + """ + if not messages: + return "user" + + for message in messages: + if message["role"] == "system": + continue + return "user" if message["role"] == "assistant" else "assistant" + + return None + + +def _insert_user_continue_message( + messages: List[AllMessageValues], + user_continue_message: Optional[ChatCompletionUserMessage], + ensure_alternating_roles: bool, +) -> List[AllMessageValues]: + """ + Inserts a user continue message into the messages list. + Handles three cases: + 1. Initial assistant message + 2. Final assistant message + 3. Consecutive assistant messages + + Only inserts messages between consecutive assistant messages, + ignoring all other role types. 
+ """ + if not messages: + return messages + + result_messages = messages.copy() # Don't modify the input list + continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE + + # Handle first message if it's an assistant message + if result_messages[0]["role"] == "assistant": + result_messages.insert(0, continue_message) + + # Handle consecutive assistant messages and final message + i = 1 # Start from second message since we handled first message + while i < len(result_messages): + curr_message = result_messages[i] + prev_message = result_messages[i - 1] + + # Only check for consecutive assistant messages + # Ignore all other role types + if curr_message["role"] == "assistant" and prev_message["role"] == "assistant": + result_messages.insert(i, continue_message) + i += 2 # Skip over the message we just inserted + else: + i += 1 + + # Handle final message + if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles: + result_messages.append(continue_message) + + return result_messages + + +def _insert_assistant_continue_message( + messages: List[AllMessageValues], + assistant_continue_message: Optional[ChatCompletionAssistantMessage] = None, + ensure_alternating_roles: bool = True, +) -> List[AllMessageValues]: + """ + Add assistant continuation messages between consecutive user messages. + + Args: + messages: List of message dictionaries + assistant_continue_message: Optional custom assistant message + ensure_alternating_roles: Whether to enforce alternating roles + + Returns: + Modified list of messages with inserted assistant messages + """ + if not ensure_alternating_roles or len(messages) <= 1: + return messages + + # Create a new list to store modified messages + modified_messages: List[AllMessageValues] = [] + + for i, message in enumerate(messages): + modified_messages.append(message) + + # Check if we need to insert an assistant message + if ( + i < len(messages) - 1 # Not the last message + and message.get("role") == "user" # Current is user + and messages[i + 1].get("role") == "user" + ): # Next is user + + # Insert assistant message + continue_message = ( + assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE + ) + modified_messages.append(continue_message) + + return modified_messages + + +def get_completion_messages( + messages: List[AllMessageValues], + assistant_continue_message: Optional[ChatCompletionAssistantMessage], + user_continue_message: Optional[ChatCompletionUserMessage], + ensure_alternating_roles: bool, +) -> List[AllMessageValues]: + """ + Ensures messages alternate between user and assistant roles by adding placeholders + only when there are consecutive messages of the same role. + + 1. ensure 'user' message before 1st 'assistant' message + 2. 
ensure 'user' message after last 'assistant' message + """ + if not ensure_alternating_roles: + return messages.copy() + + ## INSERT USER CONTINUE MESSAGE + messages = _insert_user_continue_message( + messages, user_continue_message, ensure_alternating_roles + ) + + ## INSERT ASSISTANT CONTINUE MESSAGE + messages = _insert_assistant_continue_message( + messages, assistant_continue_message, ensure_alternating_roles + ) + return messages diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 382c1d148..8dd0f67b6 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -1388,9 +1388,14 @@ def anthropic_messages_pt( for m in user_message_types_block["content"]: if m.get("type", "") == "image_url": m = cast(ChatCompletionImageObject, m) - image_chunk = convert_to_anthropic_image_obj( - openai_image_url=m["image_url"]["url"] # type: ignore - ) + if isinstance(m["image_url"], str): + image_chunk = convert_to_anthropic_image_obj( + openai_image_url=m["image_url"] + ) + else: + image_chunk = convert_to_anthropic_image_obj( + openai_image_url=m["image_url"]["url"] + ) _anthropic_content_element = AnthropicMessagesImageParam( type="image", diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index fc437865a..bee96e0a1 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -887,10 +887,16 @@ class VertexLLM(VertexBase): if "citationMetadata" in candidate: citation_metadata.append(candidate["citationMetadata"]) - if "text" in candidate["content"]["parts"][0]: + if ( + "parts" in candidate["content"] + and "text" in candidate["content"]["parts"][0] + ): content_str = candidate["content"]["parts"][0]["text"] - if "functionCall" in candidate["content"]["parts"][0]: + if ( + "parts" in candidate["content"] + and "functionCall" in candidate["content"]["parts"][0] + ): _function_chunk = ChatCompletionToolCallFunctionChunk( name=candidate["content"]["parts"][0]["functionCall"][ "name" @@ -1358,7 +1364,11 @@ class ModelResponseIterator: if _candidates and len(_candidates) > 0: gemini_chunk = _candidates[0] - if gemini_chunk and "content" in gemini_chunk: + if ( + gemini_chunk + and "content" in gemini_chunk + and "parts" in gemini_chunk["content"] + ): if "text" in gemini_chunk["content"]["parts"][0]: text = gemini_chunk["content"]["parts"][0]["text"] elif "functionCall" in gemini_chunk["content"]["parts"][0]: diff --git a/litellm/main.py b/litellm/main.py index 31e944920..96ec304e5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -109,6 +109,7 @@ from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.predibase import PredibaseChatCompletion +from .llms.prompt_templates.common_utils import get_completion_messages from .llms.prompt_templates.factory import ( custom_prompt, function_call_prompt, @@ -144,7 +145,11 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im VertexEmbedding, ) from .llms.watsonx import IBMWatsonXAI -from .types.llms.openai import HttpxBinaryResponseContent +from .types.llms.openai import ( + 
ChatCompletionAssistantMessage, + ChatCompletionUserMessage, + HttpxBinaryResponseContent, +) from .types.utils import ( AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall, @@ -748,6 +753,15 @@ def completion( # type: ignore proxy_server_request = kwargs.get("proxy_server_request", None) fallbacks = kwargs.get("fallbacks", None) headers = kwargs.get("headers", None) or extra_headers + ensure_alternating_roles: Optional[bool] = kwargs.get( + "ensure_alternating_roles", None + ) + user_continue_message: Optional[ChatCompletionUserMessage] = kwargs.get( + "user_continue_message", None + ) + assistant_continue_message: Optional[ChatCompletionAssistantMessage] = kwargs.get( + "assistant_continue_message", None + ) if headers is None: headers = {} @@ -784,7 +798,12 @@ def completion( # type: ignore ### Admin Controls ### no_log = kwargs.get("no-log", False) ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 - messages = deepcopy(messages) + messages = get_completion_messages( + messages=messages, + ensure_alternating_roles=ensure_alternating_roles or False, + user_continue_message=user_continue_message, + assistant_continue_message=assistant_continue_message, + ) ######## end of unpacking kwargs ########### openai_params = [ "functions", diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 90c9fc3d9..b82ab339f 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -5,8 +5,8 @@ model_list: api_key: os.environ/OPENAI_API_KEY -assistant_settings: - custom_llm_provider: azure - litellm_params: - api_key: os.environ/AZURE_API_KEY - api_base: os.environ/AZURE_API_BASE +litellm_settings: + callbacks: ["argilla"] + argilla_transformation_object: + user_input: "messages" + llm_output: "response" \ No newline at end of file diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index d34d5d9ef..cc0b42120 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -1,13 +1,17 @@ import re import sys import traceback -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple from fastapi import HTTPException, Request, status from litellm import Router, provider_list from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * +from litellm.types.router import ( + CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS, + ConfigurableClientsideParamsCustomAuth, +) def _get_request_ip_address( @@ -73,8 +77,41 @@ def check_complete_credentials(request_body: dict) -> bool: return False +def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool: + """ + Check if request_body_value matches the regex_str or is equal to param + """ + if re.match(regex_str, request_body_value) or regex_str == request_body_value: + return True + return False + + +def _is_param_allowed( + param: str, + request_body_value: Any, + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS, +) -> bool: + """ + Check if param is a str or dict and if request_body_value is in the list of allowed values + """ + if configurable_clientside_auth_params is None: + return False + + for item in configurable_clientside_auth_params: + if isinstance(item, str) and param == item: + return True + elif isinstance(item, Dict): + if param == "api_base" and check_regex_or_str_match( + request_body_value=request_body_value, + regex_str=item["api_base"], + ): # assume param is a regex + return True + + 
return False + + def _allow_model_level_clientside_configurable_parameters( - model: str, param: str, llm_router: Optional[Router] + model: str, param: str, request_body_value: Any, llm_router: Optional[Router] ) -> bool: """ Check if model is allowed to use configurable client-side params @@ -99,10 +136,11 @@ def _allow_model_level_clientside_configurable_parameters( if model_info is None or model_info.configurable_clientside_auth_params is None: return False - if param in model_info.configurable_clientside_auth_params: - return True - - return False + return _is_param_allowed( + param=param, + request_body_value=request_body_value, + configurable_clientside_auth_params=model_info.configurable_clientside_auth_params, + ) def is_request_body_safe( @@ -127,7 +165,10 @@ def is_request_body_safe( return True elif ( _allow_model_level_clientside_configurable_parameters( - model=model, param=param, llm_router=llm_router + model=model, + param=param, + request_body_value=request_body[param], + llm_router=llm_router, ) is True ): diff --git a/litellm/proxy/management_endpoints/sso_helper_utils.py b/litellm/proxy/management_endpoints/sso_helper_utils.py new file mode 100644 index 000000000..14b370c94 --- /dev/null +++ b/litellm/proxy/management_endpoints/sso_helper_utils.py @@ -0,0 +1,24 @@ +from fastapi import HTTPException + +from litellm.proxy._types import LitellmUserRoles + + +def check_is_admin_only_access(ui_access_mode: str) -> bool: + """Checks ui access mode is admin_only""" + return ui_access_mode == "admin_only" + + +def has_admin_ui_access(user_role: str) -> bool: + """ + Check if the user has admin access to the UI. + + Returns: + bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise. + """ + + if ( + user_role != LitellmUserRoles.PROXY_ADMIN.value + and user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value + ): + return False + return True diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index d2b0e551a..d515baa96 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -31,6 +31,10 @@ from litellm.proxy.common_utils.admin_ui_utils import ( show_missing_vars_in_env, ) from litellm.proxy.management_endpoints.internal_user_endpoints import new_user +from litellm.proxy.management_endpoints.sso_helper_utils import ( + check_is_admin_only_access, + has_admin_ui_access, +) from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: @@ -545,17 +549,16 @@ async def auth_callback(request: Request): f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" ) ## CHECK IF ROLE ALLOWED TO USE PROXY ## - if ui_access_mode == "admin_only" and ( - user_role != LitellmUserRoles.PROXY_ADMIN.value - or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value - ): - verbose_proxy_logger.debug("EXCEPTION RAISED") - raise HTTPException( - status_code=401, - detail={ - "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}" - }, - ) + is_admin_only_access = check_is_admin_only_access(ui_access_mode) + if is_admin_only_access: + has_access = has_admin_ui_access(user_role) + if not has_access: + raise HTTPException( + status_code=401, + detail={ + "error": f"User not allowed to access proxy. 
User role={user_role}, proxy mode={ui_access_mode}" + }, + ) import jwt diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 75a105339..351cba24f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1124,6 +1124,7 @@ class PrismaClient: "MonthlyGlobalSpendPerKey", "MonthlyGlobalSpendPerUserPerKey", "Last30dTopEndUsersSpend", + "DailyTagSpend", ] required_view = "LiteLLM_VerificationTokenView" expected_views_str = ", ".join(f"'{view}'" for view in expected_views) diff --git a/litellm/router.py b/litellm/router.py index aa02cfe63..233ef4feb 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -104,6 +104,7 @@ from litellm.types.llms.openai import ( Thread, ) from litellm.types.router import ( + CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS, SPECIAL_MODEL_INFO_PARAMS, VALID_LITELLM_ENVIRONMENTS, AlertingConfig, @@ -4183,7 +4184,7 @@ class Router: total_tpm: Optional[int] = None total_rpm: Optional[int] = None - configurable_clientside_auth_params: Optional[List[str]] = None + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None for model in self.model_list: is_match = False diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 465752f5b..d8a0942db 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -65,7 +65,7 @@ class HttpxPartType(TypedDict, total=False): class HttpxContentType(TypedDict, total=False): role: Literal["user", "model"] - parts: Required[List[HttpxPartType]] + parts: List[HttpxPartType] class ContentType(TypedDict, total=False): diff --git a/litellm/types/router.py b/litellm/types/router.py index f0737b3ef..6119ca4b7 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -5,10 +5,11 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc import datetime import enum import uuid -from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import httpx from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import TypedDict from ..exceptions import RateLimitError from .completion import CompletionRequest @@ -16,6 +17,15 @@ from .embedding import EmbeddingRequest from .utils import ModelResponse +class ConfigurableClientsideParamsCustomAuth(TypedDict): + api_base: str + + +CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = Optional[ + List[Union[str, ConfigurableClientsideParamsCustomAuth]] +] + + class ModelConfig(BaseModel): model_name: str litellm_params: Union[CompletionRequest, EmbeddingRequest] @@ -139,7 +149,7 @@ class GenericLiteLLMParams(BaseModel): ) max_retries: Optional[int] = None organization: Optional[str] = None # for openai orgs - configurable_clientside_auth_params: Optional[List[str]] = None + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None ## UNIFIED PROJECT/REGION ## region_name: Optional[str] = None ## VERTEX AI ## @@ -311,9 +321,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): stream_timeout: Optional[Union[float, str]] max_retries: Optional[int] organization: Optional[Union[List, str]] # for openai orgs - configurable_clientside_auth_params: Optional[ - List[str] - ] # for allowing api base switching on finetuned models + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS # for allowing api base switching on finetuned models ## DROP PARAMS ## drop_params: Optional[bool] ## UNIFIED PROJECT/REGION ## @@ -496,7 +504,7 @@ class 
ModelGroupInfo(BaseModel): supports_vision: bool = Field(default=False) supports_function_calling: bool = Field(default=False) supported_openai_params: Optional[List[str]] = Field(default=[]) - configurable_clientside_auth_params: Optional[List[str]] = None + configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None class AssistantsTypedDict(TypedDict): diff --git a/litellm/types/utils.py b/litellm/types/utils.py index cbc0f0274..409c28458 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1246,6 +1246,9 @@ all_litellm_params = [ "user_continue_message", "configurable_clientside_auth_params", "weight", + "ensure_alternating_roles", + "assistant_continue_message", + "user_continue_message", ] diff --git a/tests/local_testing/test_auth_utils.py b/tests/local_testing/test_auth_utils.py new file mode 100644 index 000000000..1118b8a63 --- /dev/null +++ b/tests/local_testing/test_auth_utils.py @@ -0,0 +1,70 @@ +# What is this? +## Tests if proxy/auth/auth_utils.py works as expected + +import sys, os, asyncio, time, random, uuid +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm.proxy.auth.auth_utils import ( + _allow_model_level_clientside_configurable_parameters, +) +from litellm.router import Router + + +@pytest.mark.parametrize( + "allowed_param, input_value, should_return_true", + [ + ("api_base", {"api_base": "http://dummy.com"}, True), + ( + {"api_base": "https://api.openai.com/v1"}, + {"api_base": "https://api.openai.com/v1"}, + True, + ), # should return True + ( + {"api_base": "https://api.openai.com/v1"}, + {"api_base": "https://api.anthropic.com/v1"}, + False, + ), # should return False + ( + {"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"}, + {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"}, + True, + ), + ( + {"api_base": "^https://litellm.*novice\.fireworks\.ai/v1$"}, + {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"}, + False, + ), + ], +) +def test_configurable_clientside_parameters( + allowed_param, input_value, should_return_true +): + router = Router( + model_list=[ + { + "model_name": "dummy-model", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "dummy-key", + "configurable_clientside_auth_params": [allowed_param], + }, + } + ] + ) + resp = _allow_model_level_clientside_configurable_parameters( + model="dummy-model", + param="api_base", + request_body_value=input_value["api_base"], + llm_router=router, + ) + print(resp) + assert resp == should_return_true diff --git a/tests/local_testing/test_prompt_factory.py b/tests/local_testing/test_prompt_factory.py index faa7e0c33..74e7cefa5 100644 --- a/tests/local_testing/test_prompt_factory.py +++ b/tests/local_testing/test_prompt_factory.py @@ -22,9 +22,13 @@ from litellm.llms.prompt_templates.factory import ( llama_2_chat_pt, prompt_factory, ) +from litellm.llms.prompt_templates.common_utils import ( + get_completion_messages, +) from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( _gemini_convert_messages_with_history, ) +from unittest.mock import AsyncMock, MagicMock, patch def test_llama_3_prompt(): @@ -457,3 +461,217 @@ def test_azure_tool_call_invoke_helper(): "function_call": {"name": "get_weather", "arguments": ""}, }, ] + + +@pytest.mark.parametrize( + "messages, expected_messages, user_continue_message, assistant_continue_message", + [ + ( + [ 
+ {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hello! How can I assist you today?"}, + {"role": "user", "content": "What is Databricks?"}, + {"role": "user", "content": "What is Azure?"}, + {"role": "assistant", "content": "I don't know anyything, do you?"}, + ], + [ + {"role": "user", "content": "Hello!"}, + { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, + {"role": "user", "content": "What is Databricks?"}, + { + "role": "assistant", + "content": "Please continue.", + }, + {"role": "user", "content": "What is Azure?"}, + { + "role": "assistant", + "content": "I don't know anyything, do you?", + }, + { + "role": "user", + "content": "Please continue.", + }, + ], + None, + None, + ), + ( + [ + {"role": "user", "content": "Hello!"}, + ], + [ + {"role": "user", "content": "Hello!"}, + ], + None, + None, + ), + ( + [ + {"role": "user", "content": "Hello!"}, + {"role": "user", "content": "What is Databricks?"}, + ], + [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "What is Databricks?"}, + ], + None, + None, + ), + ( + [ + {"role": "user", "content": "Hello!"}, + {"role": "user", "content": "What is Databricks?"}, + {"role": "user", "content": "What is Azure?"}, + ], + [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "What is Databricks?"}, + { + "role": "assistant", + "content": "Please continue.", + }, + {"role": "user", "content": "What is Azure?"}, + ], + None, + None, + ), + ( + [ + {"role": "user", "content": "Hello!"}, + { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, + {"role": "user", "content": "What is Databricks?"}, + {"role": "user", "content": "What is Azure?"}, + {"role": "assistant", "content": "I don't know anyything, do you?"}, + {"role": "assistant", "content": "I can't repeat sentences."}, + ], + [ + {"role": "user", "content": "Hello!"}, + { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, + {"role": "user", "content": "What is Databricks?"}, + { + "role": "assistant", + "content": "Please continue", + }, + {"role": "user", "content": "What is Azure?"}, + { + "role": "assistant", + "content": "I don't know anyything, do you?", + }, + { + "role": "user", + "content": "Ok", + }, + { + "role": "assistant", + "content": "I can't repeat sentences.", + }, + {"role": "user", "content": "Ok"}, + ], + { + "role": "user", + "content": "Ok", + }, + { + "role": "assistant", + "content": "Please continue", + }, + ), + ], +) +def test_ensure_alternating_roles( + messages, expected_messages, user_continue_message, assistant_continue_message +): + + messages = get_completion_messages( + messages=messages, + assistant_continue_message=assistant_continue_message, + user_continue_message=user_continue_message, + ensure_alternating_roles=True, + ) + + print(messages) + + assert messages == expected_messages + + +def test_alternating_roles_e2e(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + import json + + litellm.set_verbose = True + http_handler = HTTPHandler() + + with patch.object(http_handler, "post", new=MagicMock()) as mock_post: + response = litellm.completion( + **{ + "model": "databricks/databricks-meta-llama-3-1-70b-instruct", + "messages": [ + {"role": "user", "content": "Hello!"}, + { + "role": "assistant", + "content": "Hello! 
How can I assist you today?", + }, + {"role": "user", "content": "What is Databricks?"}, + {"role": "user", "content": "What is Azure?"}, + {"role": "assistant", "content": "I don't know anyything, do you?"}, + {"role": "assistant", "content": "I can't repeat sentences."}, + ], + "user_continue_message": { + "role": "user", + "content": "Ok", + }, + "assistant_continue_message": { + "role": "assistant", + "content": "Please continue", + }, + "ensure_alternating_roles": True, + }, + client=http_handler, + ) + print(f"response: {response}") + assert mock_post.call_args.kwargs["data"] == json.dumps( + { + "model": "databricks-meta-llama-3-1-70b-instruct", + "messages": [ + {"role": "user", "content": "Hello!"}, + { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, + {"role": "user", "content": "What is Databricks?"}, + { + "role": "assistant", + "content": "Please continue", + }, + {"role": "user", "content": "What is Azure?"}, + { + "role": "assistant", + "content": "I don't know anyything, do you?", + }, + { + "role": "user", + "content": "Ok", + }, + { + "role": "assistant", + "content": "I can't repeat sentences.", + }, + { + "role": "user", + "content": "Ok", + }, + ], + "stream": False, + } + ) diff --git a/tests/local_testing/test_ui_sso_helper_utils.py b/tests/local_testing/test_ui_sso_helper_utils.py new file mode 100644 index 000000000..c72063632 --- /dev/null +++ b/tests/local_testing/test_ui_sso_helper_utils.py @@ -0,0 +1,38 @@ +# What is this? +## This tests the batch update spend logic on the proxy server + + +import asyncio +import os +import random +import sys +import time +import traceback +from datetime import datetime + +from dotenv import load_dotenv +from fastapi import Request + +load_dotenv() + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import logging +from litellm.proxy.management_endpoints.sso_helper_utils import ( + check_is_admin_only_access, + has_admin_ui_access, +) +from litellm.proxy._types import LitellmUserRoles + + +def test_check_is_admin_only_access(): + assert check_is_admin_only_access("admin_only") is True + assert check_is_admin_only_access("user_only") is False + + +def test_has_admin_ui_access(): + assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN.value) is True + assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value) is True + assert has_admin_ui_access(LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value) is False
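

Usage note (editorial, not part of the patch): the two sketches below illustrate the client-facing features this commit introduces, based on the tests it adds (tests/local_testing/test_prompt_factory.py and tests/local_testing/test_auth_utils.py). Model names and keys are placeholders; treat these as minimal sketches rather than canonical examples.

First, the 'ensure_alternating_roles' kwarg on litellm.completion. When enabled, consecutive messages with the same role get a placeholder turn inserted between them (defaulting to "Please continue.", or the custom continue messages passed below), as exercised by test_alternating_roles_e2e:

    import litellm

    # Consecutive 'user' turns trigger insertion of the assistant continue message
    # before the request is sent to the provider.
    response = litellm.completion(
        model="databricks/databricks-meta-llama-3-1-70b-instruct",  # placeholder model
        messages=[
            {"role": "user", "content": "Hello!"},
            {"role": "user", "content": "What is Databricks?"},  # same role twice in a row
        ],
        ensure_alternating_roles=True,
        user_continue_message={"role": "user", "content": "Ok"},
        assistant_continue_message={"role": "assistant", "content": "Please continue"},
    )

Second, regex-based clientside auth credentials. A dict entry with an 'api_base' pattern in configurable_clientside_auth_params allows callers to supply any api_base matching that regex, while plain string entries keep the existing behavior of allowing the named param outright (see test_configurable_clientside_parameters):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "dummy-model",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "dummy-key",  # placeholder key
                    # Clientside api_base values are accepted only if they match this regex.
                    "configurable_clientside_auth_params": [
                        {"api_base": "^https://litellm.*direct\\.fireworks\\.ai/v1$"}
                    ],
                },
            }
        ]
    )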