LiteLLM Minor Fixes & Improvements (10/17/2024) (#6293)
* fix(ui_sso.py): fix faulty admin-only check. Fixes https://github.com/BerriAI/litellm/issues/6286
* refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing, to prevent future regressions
* feat(prompt_factory): support 'ensure_alternating_roles' param. Closes https://github.com/BerriAI/litellm/issues/6257
* fix(proxy/utils.py): add DailyTagSpend to expected views
* feat(auth_utils.py): support setting a regex for clientside auth credentials. Fixes https://github.com/BerriAI/litellm/issues/6203
* build(cookbook): add tutorial for MLflow + LangChain + LiteLLM proxy tracing
* feat(argilla.py): add Argilla logging integration. Closes https://github.com/BerriAI/litellm/issues/6201
* fix: fix linting errors
* fix: fix ruff error
* test: fix test
* fix: update Vertex AI assumption - parts not always guaranteed (#6296)
* docs(configs.md): add Argilla env vars to docs
Parent: 5e381caf75, commit: f252350881
23 changed files with 1388 additions and 43 deletions
cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb (new file, +312 lines, vendored)
@@ -0,0 +1,312 @@
(notebook cell sources, in order)

[markdown]
# Databricks Notebook with MLFlow AutoLogging for LiteLLM Proxy calls

[code]
%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6

[code]
%pip install "langchain-openai<=0.3.1"

[code]
# Before logging this chain using the driver notebook, you must comment out this line.
dbutils.library.restartPython()

[code]
import mlflow
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_databricks import ChatDatabricks
from langchain_openai import ChatOpenAI

[code]
import mlflow
mlflow.langchain.autolog()

[code]
# These helper functions parse the `messages` array.

# Return the string contents of the most recent message from the user
def extract_user_query_string(chat_messages_array):
    return chat_messages_array[-1]["content"]


# Return the chat history, which is everything before the last question
def extract_chat_history(chat_messages_array):
    return chat_messages_array[:-1]

[code]
model = ChatOpenAI(
    openai_api_base="LITELLM_PROXY_BASE_URL",  # e.g.: http://0.0.0.0:4000
    model="gpt-3.5-turbo",  # LITELLM 'model_name'
    temperature=0.1,
    api_key="LITELLM_PROXY_API_KEY",  # e.g.: "sk-1234"
)

[code]
############
# Prompt Template for generation
############
prompt = PromptTemplate(
    template="You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    input_variables=["question"],
)

############
# FM for generation
# ChatDatabricks accepts any /llm/v1/chat model serving endpoint
############
model = ChatDatabricks(
    endpoint="databricks-dbrx-instruct",
    extra_params={"temperature": 0.01, "max_tokens": 500},
)


############
# Simple chain
############
# The framework requires the chain to return a string value.
chain = (
    {
        "question": itemgetter("messages")
        | RunnableLambda(extract_user_query_string),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | StrOutputParser()
)

[code]
# This is the same input your chain's REST API will accept.
question = {
    "messages": [
        {
            "role": "user",
            "content": "what is rag?",
        },
    ]
}

chain.invoke(question)

# Example output (truncated):
# 'Hello there! I'm here to help with your questions. Regarding your query about "rag," ...'
# Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)

[code]
mlflow.models.set_model(model=model)

[notebook metadata] language: python; notebookName: "Untitled Notebook 2024-10-16 19:35:16"
configs.md
@@ -873,6 +873,11 @@ router_settings:
 | ALLOWED_EMAIL_DOMAINS | List of email domains allowed for access
 | ARIZE_API_KEY | API key for Arize platform integration
 | ARIZE_SPACE_KEY | Space key for Arize platform
+| ARGILLA_BATCH_SIZE | Batch size for Argilla logging
+| ARGILLA_API_KEY | API key for Argilla platform
+| ARGILLA_SAMPLING_RATE | Sampling rate for Argilla logging
+| ARGILLA_DATASET_NAME | Dataset name for Argilla logging
+| ARGILLA_BASE_URL | Base URL for Argilla service
 | ATHINA_API_KEY | API key for Athina service
 | AUTH_STRATEGY | Strategy used for authentication (e.g., OAuth, API key)
 | AWS_ACCESS_KEY_ID | Access Key ID for AWS services
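A rough sketch of how these Argilla variables might be set before starting the proxy or calling LiteLLM directly (the key is a placeholder; the base URL and dataset name shown match the defaults used by the new logger):

import os

os.environ["ARGILLA_API_KEY"] = "argilla.apikey"           # placeholder credential
os.environ["ARGILLA_BASE_URL"] = "http://localhost:6900"   # default used by the new logger
os.environ["ARGILLA_DATASET_NAME"] = "litellm-completion"  # default used by the new logger
os.environ["ARGILLA_BATCH_SIZE"] = "10"                    # flush after 10 queued items
os.environ["ARGILLA_SAMPLING_RATE"] = "1"                  # log every request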
@@ -54,6 +54,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "langtrace",
     "gcs_bucket",
     "opik",
+    "argilla",
 ]
 _known_custom_logger_compatible_callbacks: List = list(
     get_args(_custom_logger_compatible_callbacks_literal)

@@ -61,6 +62,8 @@ _known_custom_logger_compatible_callbacks: List = list(
 callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
+argilla_batch_size: Optional[int] = None
+argilla_transformation_object: Optional[Dict[str, Any]] = None
 _async_input_callback: List[Callable] = (
     []
 )  # internal variable - async custom callbacks are routed here.
@@ -693,7 +693,11 @@ def completion_cost(
         completion_response, RerankResponse
     ):
         meta_obj = completion_response.meta
-        billed_units = meta_obj.get("billed_units", {}) or {}
+        if meta_obj is not None:
+            billed_units = meta_obj.get("billed_units", {}) or {}
+        else:
+            billed_units = {}

         search_units = (
             billed_units.get("search_units") or 1
         )  # cohere charges per request by default.
litellm/integrations/argilla.py (new file, +410 lines)
@@ -0,0 +1,410 @@
"""
Send logs to Argilla for annotation
"""

import asyncio
import json
import os
import random
import time
import traceback
import types
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TypedDict, Union

import dotenv  # type: ignore
import httpx
import requests  # type: ignore
from pydantic import BaseModel  # type: ignore

import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.types.utils import StandardLoggingPayload


class LangsmithInputs(BaseModel):
    model: Optional[str] = None
    messages: Optional[List[Any]] = None
    stream: Optional[bool] = None
    call_type: Optional[str] = None
    litellm_call_id: Optional[str] = None
    completion_start_time: Optional[datetime] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    custom_llm_provider: Optional[str] = None
    input: Optional[List[Any]] = None
    log_event_type: Optional[str] = None
    original_response: Optional[Any] = None
    response_cost: Optional[float] = None

    # LiteLLM Virtual Key specific fields
    user_api_key: Optional[str] = None
    user_api_key_user_id: Optional[str] = None
    user_api_key_team_alias: Optional[str] = None


class ArgillaItem(TypedDict):
    fields: Dict[str, Any]


class ArgillaPayload(TypedDict):
    items: List[ArgillaItem]


class ArgillaCredentialsObject(TypedDict):
    ARGILLA_API_KEY: str
    ARGILLA_DATASET_NAME: str
    ARGILLA_BASE_URL: str


SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]


def is_serializable(value):
    non_serializable_types = (
        types.CoroutineType,
        types.FunctionType,
        types.GeneratorType,
        BaseModel,
    )
    return not isinstance(value, non_serializable_types)


class ArgillaLogger(CustomBatchLogger):
    def __init__(
        self,
        argilla_api_key: Optional[str] = None,
        argilla_dataset_name: Optional[str] = None,
        argilla_base_url: Optional[str] = None,
        **kwargs,
    ):
        if litellm.argilla_transformation_object is None:
            raise Exception(
                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
            )
        self.validate_argilla_transformation_object(
            litellm.argilla_transformation_object
        )
        self.argilla_transformation_object = litellm.argilla_transformation_object
        self.default_credentials = self.get_credentials_from_env(
            argilla_api_key=argilla_api_key,
            argilla_dataset_name=argilla_dataset_name,
            argilla_base_url=argilla_base_url,
        )
        self.sampling_rate: float = (
            float(os.getenv("ARGILLA_SAMPLING_RATE"))  # type: ignore
            if os.getenv("ARGILLA_SAMPLING_RATE") is not None
            and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit()  # type: ignore
            else 1.0
        )

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        _batch_size = (
            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
        )
        if _batch_size:
            self.batch_size = int(_batch_size)
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    def validate_argilla_transformation_object(
        self, argilla_transformation_object: Dict[str, Any]
    ):
        if not isinstance(argilla_transformation_object, dict):
            raise Exception(
                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
            )

        for v in argilla_transformation_object.values():
            if v not in SUPPORTED_PAYLOAD_FIELDS:
                raise Exception(
                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
                )

    def get_credentials_from_env(
        self,
        argilla_api_key: Optional[str],
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:

        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if _credentials_api_key is None:
            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

        _credentials_base_url = (
            argilla_base_url
            or os.getenv("ARGILLA_BASE_URL")
            or "http://localhost:6900/"
        )
        if _credentials_base_url is None:
            raise Exception(
                "Invalid Argilla Base URL given. _credentials_base_url=None."
            )

        _credentials_dataset_name = (
            argilla_dataset_name
            or os.getenv("ARGILLA_DATASET_NAME")
            or "litellm-completion"
        )
        if _credentials_dataset_name is None:
            raise Exception("Invalid Argilla Dataset give. Value=None.")
        else:
            dataset_response = litellm.module_level_client.get(
                url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
                headers={"X-Argilla-Api-Key": _credentials_api_key},
            )
            json_response = dataset_response.json()
            if (
                "items" in json_response
                and isinstance(json_response["items"], list)
                and len(json_response["items"]) > 0
            ):
                _credentials_dataset_name = json_response["items"][0]["id"]

        return ArgillaCredentialsObject(
            ARGILLA_API_KEY=_credentials_api_key,
            ARGILLA_BASE_URL=_credentials_base_url,
            ARGILLA_DATASET_NAME=_credentials_dataset_name,
        )

    def get_chat_messages(
        self, payload: StandardLoggingPayload
    ) -> List[Dict[str, Any]]:
        payload_messages = payload.get("messages", None)

        if payload_messages is None:
            raise Exception("No chat messages found in payload.")

        if (
            isinstance(payload_messages, list)
            and len(payload_messages) > 0
            and isinstance(payload_messages[0], dict)
        ):
            return payload_messages
        elif isinstance(payload_messages, dict):
            return [payload_messages]
        else:
            raise Exception(f"Invalid chat messages format: {payload_messages}")

    def get_str_response(self, payload: StandardLoggingPayload) -> str:
        response = payload["response"]

        if response is None:
            raise Exception("No response found in payload.")

        if isinstance(response, str):
            return response
        elif isinstance(response, dict):
            return (
                response.get("choices", [{}])[0].get("message", {}).get("content", "")
            )
        else:
            raise Exception(f"Invalid response format: {response}")

    def _prepare_log_data(
        self, kwargs, response_obj, start_time, end_time
    ) -> ArgillaItem:
        try:
            # Ensure everything in the payload is converted to str
            payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )

            if payload is None:
                raise Exception("Error logging request payload. Payload=none.")

            argilla_message = self.get_chat_messages(payload)
            argilla_response = self.get_str_response(payload)
            argilla_item: ArgillaItem = {"fields": {}}
            for k, v in self.argilla_transformation_object.items():
                if v == "messages":
                    argilla_item["fields"][k] = argilla_message
                elif v == "response":
                    argilla_item["fields"][k] = argilla_response
                else:
                    argilla_item["fields"][k] = payload.get(v, None)
            return argilla_item
        except Exception:
            raise

    def _send_batch(self):
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]

        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"

        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]

        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = requests.post(
                url=url,
                json=self.log_queue,
                headers=headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )

            self.log_queue.clear()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - Error sending batch.")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            sampling_rate = (
                float(os.getenv("LANGSMITH_SAMPLING_RATE"))  # type: ignore
                if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
                and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()  # type: ignore
                else 1.0
            )
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
            )

            if len(self.log_queue) >= self.batch_size:
                self._send_batch()

        except Exception:
            verbose_logger.exception("Langsmith Layer Error - log_success_event error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            sampling_rate = self.sampling_rate
            random_sample = random.random()
            if random_sample > sampling_rate:
                verbose_logger.info(
                    "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
                        sampling_rate, random_sample
                    )
                )
                return  # Skip logging
            verbose_logger.debug(
                "Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Langsmith logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        sampling_rate = self.sampling_rate
        random_sample = random.random()
        if random_sample > sampling_rate:
            verbose_logger.info(
                "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
                    sampling_rate, random_sample
                )
            )
            return  # Skip logging
        verbose_logger.info("Langsmith Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Langsmith logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Langsmith Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
        """
        sends runs to /batch endpoint

        Sends runs from self.log_queue

        Returns: None

        Raises: Does not raise an exception, will only verbose_logger.exception()
        """
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]

        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"

        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]

        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = await self.async_httpx_client.put(
                url=url,
                data=json.dumps(
                    {
                        "items": self.log_queue,
                    }
                ),
                headers=headers,
                timeout=60000,
            )
            response.raise_for_status()

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError:
            verbose_logger.exception("Argilla HTTP Error")
        except Exception:
            verbose_logger.exception("Argilla Layer Error")
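A minimal usage sketch for the new integration (the dataset field names "user_input" and "llm_output" are illustrative, and the model name is a placeholder):

import litellm

# Map Argilla dataset fields to the supported payload fields ("messages", "response").
litellm.argilla_transformation_object = {
    "user_input": "messages",
    "llm_output": "response",
}

litellm.callbacks = ["argilla"]  # routes logs through the new ArgillaLogger

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)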
@@ -59,6 +59,7 @@ from litellm.utils import (
 )

 from ..integrations.aispend import AISpendLogger
+from ..integrations.argilla import ArgillaLogger
 from ..integrations.athina import AthinaLogger
 from ..integrations.berrispend import BerriSpendLogger
 from ..integrations.braintrust_logging import BraintrustLogger

@@ -2339,6 +2340,14 @@ def _init_custom_logger_compatible_class(
         _langsmith_logger = LangsmithLogger()
         _in_memory_loggers.append(_langsmith_logger)
         return _langsmith_logger  # type: ignore
+    elif logging_integration == "argilla":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, ArgillaLogger):
+                return callback  # type: ignore
+
+        _argilla_logger = ArgillaLogger()
+        _in_memory_loggers.append(_argilla_logger)
+        return _argilla_logger  # type: ignore
     elif logging_integration == "literalai":
         for callback in _in_memory_loggers:
             if isinstance(callback, LiteralAILogger):

@@ -2521,6 +2530,10 @@ def get_custom_logger_compatible_class(
         for callback in _in_memory_loggers:
             if isinstance(callback, LangsmithLogger):
                 return callback
+    elif logging_integration == "argilla":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, ArgillaLogger):
+                return callback
     elif logging_integration == "literalai":
         for callback in _in_memory_loggers:
             if isinstance(callback, LiteralAILogger):
@@ -163,10 +163,11 @@ class AsyncHTTPHandler:
         try:
             if timeout is None:
                 timeout = self.timeout
+
             req = self.client.build_request(
                 "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
             )
-            response = await self.client.send(req, stream=stream)
+            response = await self.client.send(req)
             response.raise_for_status()
             return response
         except (httpx.RemoteProtocolError, httpx.ConnectError):
@@ -3,11 +3,25 @@ Common utility functions used for translating messages across providers
 """

 import json
-from typing import Dict, List
+from copy import deepcopy
+from typing import Dict, List, Literal, Optional

-from litellm.types.llms.openai import AllMessageValues
+import litellm
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionAssistantMessage,
+    ChatCompletionUserMessage,
+)
 from litellm.types.utils import Choices, ModelResponse, StreamingChoices

+DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
+    content="Please continue.", role="user"
+)
+
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
+    content="Please continue.", role="assistant"
+)
+

 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """

@@ -69,3 +83,145 @@ def get_content_from_model_response(response: ModelResponse) -> str:
         elif isinstance(choice, StreamingChoices):
             content += getattr(choice, "delta", {}).get("content", "") or ""
     return content
+
+
+def detect_first_expected_role(
+    messages: List[AllMessageValues],
+) -> Optional[Literal["user", "assistant"]]:
+    """
+    Detect the first expected role based on the message sequence.
+
+    Rules:
+    1. If messages list is empty, assume 'user' starts
+    2. If first message is from assistant, expect 'user' next
+    3. If first message is from user, expect 'assistant' next
+    4. If first message is system, look at the next non-system message
+
+    Returns:
+        str: Either 'user' or 'assistant'
+        None: If no 'user' or 'assistant' messages provided
+    """
+    if not messages:
+        return "user"
+
+    for message in messages:
+        if message["role"] == "system":
+            continue
+        return "user" if message["role"] == "assistant" else "assistant"
+
+    return None
+
+
+def _insert_user_continue_message(
+    messages: List[AllMessageValues],
+    user_continue_message: Optional[ChatCompletionUserMessage],
+    ensure_alternating_roles: bool,
+) -> List[AllMessageValues]:
+    """
+    Inserts a user continue message into the messages list.
+    Handles three cases:
+    1. Initial assistant message
+    2. Final assistant message
+    3. Consecutive assistant messages
+
+    Only inserts messages between consecutive assistant messages,
+    ignoring all other role types.
+    """
+    if not messages:
+        return messages
+
+    result_messages = messages.copy()  # Don't modify the input list
+    continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE
+
+    # Handle first message if it's an assistant message
+    if result_messages[0]["role"] == "assistant":
+        result_messages.insert(0, continue_message)
+
+    # Handle consecutive assistant messages and final message
+    i = 1  # Start from second message since we handled first message
+    while i < len(result_messages):
+        curr_message = result_messages[i]
+        prev_message = result_messages[i - 1]
+
+        # Only check for consecutive assistant messages
+        # Ignore all other role types
+        if curr_message["role"] == "assistant" and prev_message["role"] == "assistant":
+            result_messages.insert(i, continue_message)
+            i += 2  # Skip over the message we just inserted
+        else:
+            i += 1
+
+    # Handle final message
+    if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles:
+        result_messages.append(continue_message)
+
+    return result_messages
+
+
+def _insert_assistant_continue_message(
+    messages: List[AllMessageValues],
+    assistant_continue_message: Optional[ChatCompletionAssistantMessage] = None,
+    ensure_alternating_roles: bool = True,
+) -> List[AllMessageValues]:
+    """
+    Add assistant continuation messages between consecutive user messages.
+
+    Args:
+        messages: List of message dictionaries
+        assistant_continue_message: Optional custom assistant message
+        ensure_alternating_roles: Whether to enforce alternating roles
+
+    Returns:
+        Modified list of messages with inserted assistant messages
+    """
+    if not ensure_alternating_roles or len(messages) <= 1:
+        return messages
+
+    # Create a new list to store modified messages
+    modified_messages: List[AllMessageValues] = []
+
+    for i, message in enumerate(messages):
+        modified_messages.append(message)
+
+        # Check if we need to insert an assistant message
+        if (
+            i < len(messages) - 1  # Not the last message
+            and message.get("role") == "user"  # Current is user
+            and messages[i + 1].get("role") == "user"  # Next is user
+        ):
+            # Insert assistant message
+            continue_message = (
+                assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
+            )
+            modified_messages.append(continue_message)
+
+    return modified_messages
+
+
+def get_completion_messages(
+    messages: List[AllMessageValues],
+    assistant_continue_message: Optional[ChatCompletionAssistantMessage],
+    user_continue_message: Optional[ChatCompletionUserMessage],
+    ensure_alternating_roles: bool,
+) -> List[AllMessageValues]:
+    """
+    Ensures messages alternate between user and assistant roles by adding placeholders
+    only when there are consecutive messages of the same role.
+
+    1. ensure 'user' message before 1st 'assistant' message
+    2. ensure 'user' message after last 'assistant' message
+    """
+    if not ensure_alternating_roles:
+        return messages.copy()
+
+    ## INSERT USER CONTINUE MESSAGE
+    messages = _insert_user_continue_message(
+        messages, user_continue_message, ensure_alternating_roles
+    )
+
+    ## INSERT ASSISTANT CONTINUE MESSAGE
+    messages = _insert_assistant_continue_message(
+        messages, assistant_continue_message, ensure_alternating_roles
+    )
+    return messages
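A short sketch of how the new helper behaves on a hypothetical message list with consecutive same-role turns (the defaults fall back to the "Please continue." placeholders defined above):

from litellm.llms.prompt_templates.common_utils import get_completion_messages

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Are you there?"},  # consecutive user turns
    {"role": "assistant", "content": "Yes - how can I help?"},
]

fixed = get_completion_messages(
    messages=messages,
    assistant_continue_message=None,  # falls back to DEFAULT_ASSISTANT_CONTINUE_MESSAGE
    user_continue_message=None,       # falls back to DEFAULT_USER_CONTINUE_MESSAGE
    ensure_alternating_roles=True,
)
# 'fixed' now alternates roles, with an assistant "Please continue." inserted between the two user turns.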
@@ -1388,9 +1388,14 @@ def anthropic_messages_pt(
         for m in user_message_types_block["content"]:
             if m.get("type", "") == "image_url":
                 m = cast(ChatCompletionImageObject, m)
-                image_chunk = convert_to_anthropic_image_obj(
-                    openai_image_url=m["image_url"]["url"]  # type: ignore
-                )
+                if isinstance(m["image_url"], str):
+                    image_chunk = convert_to_anthropic_image_obj(
+                        openai_image_url=m["image_url"]
+                    )
+                else:
+                    image_chunk = convert_to_anthropic_image_obj(
+                        openai_image_url=m["image_url"]["url"]
+                    )

                 _anthropic_content_element = AnthropicMessagesImageParam(
                     type="image",
@@ -887,10 +887,16 @@ class VertexLLM(VertexBase):

             if "citationMetadata" in candidate:
                 citation_metadata.append(candidate["citationMetadata"])
-            if "text" in candidate["content"]["parts"][0]:
+            if (
+                "parts" in candidate["content"]
+                and "text" in candidate["content"]["parts"][0]
+            ):
                 content_str = candidate["content"]["parts"][0]["text"]

-            if "functionCall" in candidate["content"]["parts"][0]:
+            if (
+                "parts" in candidate["content"]
+                and "functionCall" in candidate["content"]["parts"][0]
+            ):
                 _function_chunk = ChatCompletionToolCallFunctionChunk(
                     name=candidate["content"]["parts"][0]["functionCall"][
                         "name"

@@ -1358,7 +1364,11 @@ class ModelResponseIterator:
             if _candidates and len(_candidates) > 0:
                 gemini_chunk = _candidates[0]

-            if gemini_chunk and "content" in gemini_chunk:
+            if (
+                gemini_chunk
+                and "content" in gemini_chunk
+                and "parts" in gemini_chunk["content"]
+            ):
                 if "text" in gemini_chunk["content"]["parts"][0]:
                     text = gemini_chunk["content"]["parts"][0]["text"]
                 elif "functionCall" in gemini_chunk["content"]["parts"][0]:
@@ -109,6 +109,7 @@ from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
 from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
 from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
+from .llms.prompt_templates.common_utils import get_completion_messages
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,

@@ -144,7 +145,11 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
     VertexEmbedding,
 )
 from .llms.watsonx import IBMWatsonXAI
-from .types.llms.openai import HttpxBinaryResponseContent
+from .types.llms.openai import (
+    ChatCompletionAssistantMessage,
+    ChatCompletionUserMessage,
+    HttpxBinaryResponseContent,
+)
 from .types.utils import (
     AdapterCompletionStreamWrapper,
     ChatCompletionMessageToolCall,

@@ -748,6 +753,15 @@ def completion(  # type: ignore
     proxy_server_request = kwargs.get("proxy_server_request", None)
     fallbacks = kwargs.get("fallbacks", None)
     headers = kwargs.get("headers", None) or extra_headers
+    ensure_alternating_roles: Optional[bool] = kwargs.get(
+        "ensure_alternating_roles", None
+    )
+    user_continue_message: Optional[ChatCompletionUserMessage] = kwargs.get(
+        "user_continue_message", None
+    )
+    assistant_continue_message: Optional[ChatCompletionAssistantMessage] = kwargs.get(
+        "assistant_continue_message", None
+    )
     if headers is None:
         headers = {}

@@ -784,7 +798,12 @@ def completion(  # type: ignore
     ### Admin Controls ###
     no_log = kwargs.get("no-log", False)
     ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
-    messages = deepcopy(messages)
+    messages = get_completion_messages(
+        messages=messages,
+        ensure_alternating_roles=ensure_alternating_roles or False,
+        user_continue_message=user_continue_message,
+        assistant_continue_message=assistant_continue_message,
+    )
     ######## end of unpacking kwargs ###########
     openai_params = [
         "functions",
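From the caller's side, the new kwargs unpacked above might be used roughly like this (the model name and custom continue message are placeholders):

import litellm

resp = litellm.completion(
    model="claude-3-5-sonnet-20240620",  # placeholder model
    messages=[
        {"role": "user", "content": "step 1"},
        {"role": "user", "content": "step 2"},  # consecutive user turns get a placeholder between them
    ],
    ensure_alternating_roles=True,
    assistant_continue_message={"role": "assistant", "content": "Please continue."},
)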
@@ -5,8 +5,8 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY


-assistant_settings:
-  custom_llm_provider: azure
-  litellm_params:
-    api_key: os.environ/AZURE_API_KEY
-    api_base: os.environ/AZURE_API_BASE
+litellm_settings:
+  callbacks: ["argilla"]
+  argilla_transformation_object:
+    user_input: "messages"
+    llm_output: "response"
@@ -1,13 +1,17 @@
 import re
 import sys
 import traceback
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 from fastapi import HTTPException, Request, status

 from litellm import Router, provider_list
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *
+from litellm.types.router import (
+    CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+    ConfigurableClientsideParamsCustomAuth,
+)


 def _get_request_ip_address(

@@ -73,8 +77,41 @@ def check_complete_credentials(request_body: dict) -> bool:
     return False


+def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
+    """
+    Check if request_body_value matches the regex_str or is equal to param
+    """
+    if re.match(regex_str, request_body_value) or regex_str == request_body_value:
+        return True
+    return False
+
+
+def _is_param_allowed(
+    param: str,
+    request_body_value: Any,
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
+) -> bool:
+    """
+    Check if param is a str or dict and if request_body_value is in the list of allowed values
+    """
+    if configurable_clientside_auth_params is None:
+        return False
+
+    for item in configurable_clientside_auth_params:
+        if isinstance(item, str) and param == item:
+            return True
+        elif isinstance(item, Dict):
+            if param == "api_base" and check_regex_or_str_match(
+                request_body_value=request_body_value,
+                regex_str=item["api_base"],
+            ):  # assume param is a regex
+                return True
+
+    return False
+
+
 def _allow_model_level_clientside_configurable_parameters(
-    model: str, param: str, llm_router: Optional[Router]
+    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
 ) -> bool:
     """
     Check if model is allowed to use configurable client-side params

@@ -99,10 +136,11 @@ def _allow_model_level_clientside_configurable_parameters(
     if model_info is None or model_info.configurable_clientside_auth_params is None:
         return False

-    if param in model_info.configurable_clientside_auth_params:
-        return True
-    return False
+    return _is_param_allowed(
+        param=param,
+        request_body_value=request_body_value,
+        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
+    )


 def is_request_body_safe(

@@ -127,7 +165,10 @@ def is_request_body_safe(
         return True
     elif (
         _allow_model_level_clientside_configurable_parameters(
-            model=model, param=param, llm_router=llm_router
+            model=model,
+            param=param,
+            request_body_value=request_body[param],
+            llm_router=llm_router,
        )
         is True
     ):
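A sketch of the router configuration this change enables - allowing a client-supplied api_base only when it matches a regex (the model names, key, and URL pattern below are placeholders):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-finetuned-model",  # placeholder model group
            "litellm_params": {
                "model": "openai/my-finetuned-model",  # placeholder
                "api_key": "sk-1234",                  # placeholder
                "configurable_clientside_auth_params": [
                    # dict entry: the api_base sent in the request body must match this regex
                    {"api_base": r"^https://.*\.example-inference\.ai/v1$"}
                ],
            },
        }
    ]
)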
litellm/proxy/management_endpoints/sso_helper_utils.py (new file, +24 lines)
@@ -0,0 +1,24 @@
from fastapi import HTTPException

from litellm.proxy._types import LitellmUserRoles


def check_is_admin_only_access(ui_access_mode: str) -> bool:
    """Checks ui access mode is admin_only"""
    return ui_access_mode == "admin_only"


def has_admin_ui_access(user_role: str) -> bool:
    """
    Check if the user has admin access to the UI.

    Returns:
        bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise.
    """

    if (
        user_role != LitellmUserRoles.PROXY_ADMIN.value
        and user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
    ):
        return False
    return True
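The two helpers compose as a simple guard; a quick sketch of the intended check (the role string is illustrative):

from litellm.proxy.management_endpoints.sso_helper_utils import (
    check_is_admin_only_access,
    has_admin_ui_access,
)

if check_is_admin_only_access("admin_only") and not has_admin_ui_access("internal_user"):
    # /sso/callback responds with a 401 HTTPException in this case (see the ui_sso.py hunk below)
    ...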
@@ -31,6 +31,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     show_missing_vars_in_env,
 )
 from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
+from litellm.proxy.management_endpoints.sso_helper_utils import (
+    check_is_admin_only_access,
+    has_admin_ui_access,
+)
 from litellm.secret_managers.main import str_to_bool

 if TYPE_CHECKING:

@@ -545,17 +549,16 @@ async def auth_callback(request: Request):
         f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
     )
     ## CHECK IF ROLE ALLOWED TO USE PROXY ##
-    if ui_access_mode == "admin_only" and (
-        user_role != LitellmUserRoles.PROXY_ADMIN.value
-        or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
-    ):
-        verbose_proxy_logger.debug("EXCEPTION RAISED")
-        raise HTTPException(
-            status_code=401,
-            detail={
-                "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
-            },
-        )
+    is_admin_only_access = check_is_admin_only_access(ui_access_mode)
+    if is_admin_only_access:
+        has_access = has_admin_ui_access(user_role)
+        if not has_access:
+            raise HTTPException(
+                status_code=401,
+                detail={
+                    "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
+                },
+            )

 import jwt
@@ -1124,6 +1124,7 @@ class PrismaClient:
             "MonthlyGlobalSpendPerKey",
             "MonthlyGlobalSpendPerUserPerKey",
             "Last30dTopEndUsersSpend",
+            "DailyTagSpend",
         ]
         required_view = "LiteLLM_VerificationTokenView"
         expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
@@ -104,6 +104,7 @@ from litellm.types.llms.openai import (
     Thread,
 )
 from litellm.types.router import (
+    CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
     SPECIAL_MODEL_INFO_PARAMS,
     VALID_LITELLM_ENVIRONMENTS,
     AlertingConfig,

@@ -4183,7 +4184,7 @@ class Router:

         total_tpm: Optional[int] = None
         total_rpm: Optional[int] = None
-        configurable_clientside_auth_params: Optional[List[str]] = None
+        configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None

         for model in self.model_list:
             is_match = False
@@ -65,7 +65,7 @@ class HttpxPartType(TypedDict, total=False):

 class HttpxContentType(TypedDict, total=False):
     role: Literal["user", "model"]
-    parts: Required[List[HttpxPartType]]
+    parts: List[HttpxPartType]


 class ContentType(TypedDict, total=False):
@@ -5,10 +5,11 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 import datetime
 import enum
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx
 from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import TypedDict
 
 from ..exceptions import RateLimitError
 from .completion import CompletionRequest
@@ -16,6 +17,15 @@ from .embedding import EmbeddingRequest
 from .utils import ModelResponse
 
 
+class ConfigurableClientsideParamsCustomAuth(TypedDict):
+    api_base: str
+
+
+CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = Optional[
+    List[Union[str, ConfigurableClientsideParamsCustomAuth]]
+]
+
+
 class ModelConfig(BaseModel):
     model_name: str
     litellm_params: Union[CompletionRequest, EmbeddingRequest]
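With this alias, `configurable_clientside_auth_params` accepts either plain parameter names or `{"api_base": ...}` entries, including a regex, as exercised by the tests added below. A sketch of how a model entry might opt in; the model name and key here are placeholders:

```python
# Illustrative model_list entry; names and keys are placeholders.
model_entry = {
    "model_name": "my-finetuned-model",
    "litellm_params": {
        "model": "gpt-3.5-turbo",
        "api_key": "dummy-key",
        # Allow clients to override api_base, but only if it matches this pattern.
        "configurable_clientside_auth_params": [
            {"api_base": r"^https://litellm.*direct\.fireworks\.ai/v1$"},
        ],
    },
}
```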
@@ -139,7 +149,7 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
-    configurable_clientside_auth_params: Optional[List[str]] = None
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
     ## UNIFIED PROJECT/REGION ##
     region_name: Optional[str] = None
     ## VERTEX AI ##
@@ -311,9 +321,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
     organization: Optional[Union[List, str]]  # for openai orgs
-    configurable_clientside_auth_params: Optional[
-        List[str]
-    ]  # for allowing api base switching on finetuned models
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS  # for allowing api base switching on finetuned models
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
@@ -496,7 +504,7 @@ class ModelGroupInfo(BaseModel):
     supports_vision: bool = Field(default=False)
     supports_function_calling: bool = Field(default=False)
     supported_openai_params: Optional[List[str]] = Field(default=[])
-    configurable_clientside_auth_params: Optional[List[str]] = None
+    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
 
 
 class AssistantsTypedDict(TypedDict):
@@ -1246,6 +1246,9 @@ all_litellm_params = [
     "user_continue_message",
     "configurable_clientside_auth_params",
     "weight",
+    "ensure_alternating_roles",
+    "assistant_continue_message",
+    "user_continue_message",
 ]
70 tests/local_testing/test_auth_utils.py Normal file
@@ -0,0 +1,70 @@
+# What is this?
+## Tests if proxy/auth/auth_utils.py works as expected
+
+import sys, os, asyncio, time, random, uuid
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm.proxy.auth.auth_utils import (
+    _allow_model_level_clientside_configurable_parameters,
+)
+from litellm.router import Router
+
+
+@pytest.mark.parametrize(
+    "allowed_param, input_value, should_return_true",
+    [
+        ("api_base", {"api_base": "http://dummy.com"}, True),
+        (
+            {"api_base": "https://api.openai.com/v1"},
+            {"api_base": "https://api.openai.com/v1"},
+            True,
+        ),  # should return True
+        (
+            {"api_base": "https://api.openai.com/v1"},
+            {"api_base": "https://api.anthropic.com/v1"},
+            False,
+        ),  # should return False
+        (
+            {"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"},
+            {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
+            True,
+        ),
+        (
+            {"api_base": "^https://litellm.*novice\.fireworks\.ai/v1$"},
+            {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
+            False,
+        ),
+    ],
+)
+def test_configurable_clientside_parameters(
+    allowed_param, input_value, should_return_true
+):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "dummy-model",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "dummy-key",
+                    "configurable_clientside_auth_params": [allowed_param],
+                },
+            }
+        ]
+    )
+    resp = _allow_model_level_clientside_configurable_parameters(
+        model="dummy-model",
+        param="api_base",
+        request_body_value=input_value["api_base"],
+        llm_router=router,
+    )
+    print(resp)
+    assert resp == should_return_true
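The new test covers both exact-match and regex-style `api_base` allowlists. The implementation of `_allow_model_level_clientside_configurable_parameters` is not shown in this diff; a minimal sketch of the matching behavior the test cases imply, where the helper name and structure are assumptions of this note, not the actual function:

```python
import re
from typing import Union


def _api_base_allowed(allowed: Union[str, dict], requested_api_base: str) -> bool:
    # Sketch: a bare "api_base" string allows any value for that param;
    # a {"api_base": ...} entry allows only an exact match or a regex match.
    if isinstance(allowed, str):
        return allowed == "api_base"
    pattern = allowed.get("api_base", "")
    if pattern == requested_api_base:
        return True
    return re.match(pattern, requested_api_base) is not None


# Mirrors the parametrized cases above:
print(_api_base_allowed("api_base", "http://dummy.com"))  # True
print(_api_base_allowed({"api_base": r"^https://litellm.*direct\.fireworks\.ai/v1$"},
                        "https://litellm-dev.direct.fireworks.ai/v1"))  # True
print(_api_base_allowed({"api_base": r"^https://litellm.*novice\.fireworks\.ai/v1$"},
                        "https://litellm-dev.direct.fireworks.ai/v1"))  # False
```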
@@ -22,9 +22,13 @@ from litellm.llms.prompt_templates.factory import (
     llama_2_chat_pt,
     prompt_factory,
 )
+from litellm.llms.prompt_templates.common_utils import (
+    get_completion_messages,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
     _gemini_convert_messages_with_history,
 )
+from unittest.mock import AsyncMock, MagicMock, patch
 
 
 def test_llama_3_prompt():
@@ -457,3 +461,217 @@ def test_azure_tool_call_invoke_helper():
             "function_call": {"name": "get_weather", "arguments": ""},
         },
     ]
+
+
+@pytest.mark.parametrize(
+    "messages, expected_messages, user_continue_message, assistant_continue_message",
+    [
+        (
+            [
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Hello! How can I assist you today?"},
+                {"role": "user", "content": "What is Databricks?"},
+                {"role": "user", "content": "What is Azure?"},
+                {"role": "assistant", "content": "I don't know anyything, do you?"},
+            ],
+            [
+                {"role": "user", "content": "Hello!"},
+                {
+                    "role": "assistant",
+                    "content": "Hello! How can I assist you today?",
+                },
+                {"role": "user", "content": "What is Databricks?"},
+                {
+                    "role": "assistant",
+                    "content": "Please continue.",
+                },
+                {"role": "user", "content": "What is Azure?"},
+                {
+                    "role": "assistant",
+                    "content": "I don't know anyything, do you?",
+                },
+                {
+                    "role": "user",
+                    "content": "Please continue.",
+                },
+            ],
+            None,
+            None,
+        ),
+        (
+            [
+                {"role": "user", "content": "Hello!"},
+            ],
+            [
+                {"role": "user", "content": "Hello!"},
+            ],
+            None,
+            None,
+        ),
+        (
+            [
+                {"role": "user", "content": "Hello!"},
+                {"role": "user", "content": "What is Databricks?"},
+            ],
+            [
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Please continue."},
+                {"role": "user", "content": "What is Databricks?"},
+            ],
+            None,
+            None,
+        ),
+        (
+            [
+                {"role": "user", "content": "Hello!"},
+                {"role": "user", "content": "What is Databricks?"},
+                {"role": "user", "content": "What is Azure?"},
+            ],
+            [
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Please continue."},
+                {"role": "user", "content": "What is Databricks?"},
+                {
+                    "role": "assistant",
+                    "content": "Please continue.",
+                },
+                {"role": "user", "content": "What is Azure?"},
+            ],
+            None,
+            None,
+        ),
+        (
+            [
+                {"role": "user", "content": "Hello!"},
+                {
+                    "role": "assistant",
+                    "content": "Hello! How can I assist you today?",
+                },
+                {"role": "user", "content": "What is Databricks?"},
+                {"role": "user", "content": "What is Azure?"},
+                {"role": "assistant", "content": "I don't know anyything, do you?"},
+                {"role": "assistant", "content": "I can't repeat sentences."},
+            ],
+            [
+                {"role": "user", "content": "Hello!"},
+                {
+                    "role": "assistant",
+                    "content": "Hello! How can I assist you today?",
+                },
+                {"role": "user", "content": "What is Databricks?"},
+                {
+                    "role": "assistant",
+                    "content": "Please continue",
+                },
+                {"role": "user", "content": "What is Azure?"},
+                {
+                    "role": "assistant",
+                    "content": "I don't know anyything, do you?",
+                },
+                {
+                    "role": "user",
+                    "content": "Ok",
+                },
+                {
+                    "role": "assistant",
+                    "content": "I can't repeat sentences.",
+                },
+                {"role": "user", "content": "Ok"},
+            ],
+            {
+                "role": "user",
+                "content": "Ok",
+            },
+            {
+                "role": "assistant",
+                "content": "Please continue",
+            },
+        ),
+    ],
+)
+def test_ensure_alternating_roles(
+    messages, expected_messages, user_continue_message, assistant_continue_message
+):
+
+    messages = get_completion_messages(
+        messages=messages,
+        assistant_continue_message=assistant_continue_message,
+        user_continue_message=user_continue_message,
+        ensure_alternating_roles=True,
+    )
+
+    print(messages)
+
+    assert messages == expected_messages
+
+
+def test_alternating_roles_e2e():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    import json
+
+    litellm.set_verbose = True
+    http_handler = HTTPHandler()
+
+    with patch.object(http_handler, "post", new=MagicMock()) as mock_post:
+        response = litellm.completion(
+            **{
+                "model": "databricks/databricks-meta-llama-3-1-70b-instruct",
+                "messages": [
+                    {"role": "user", "content": "Hello!"},
+                    {
+                        "role": "assistant",
+                        "content": "Hello! How can I assist you today?",
+                    },
+                    {"role": "user", "content": "What is Databricks?"},
+                    {"role": "user", "content": "What is Azure?"},
+                    {"role": "assistant", "content": "I don't know anyything, do you?"},
+                    {"role": "assistant", "content": "I can't repeat sentences."},
+                ],
+                "user_continue_message": {
+                    "role": "user",
+                    "content": "Ok",
+                },
+                "assistant_continue_message": {
+                    "role": "assistant",
+                    "content": "Please continue",
+                },
+                "ensure_alternating_roles": True,
+            },
+            client=http_handler,
+        )
+        print(f"response: {response}")
+        assert mock_post.call_args.kwargs["data"] == json.dumps(
+            {
+                "model": "databricks-meta-llama-3-1-70b-instruct",
+                "messages": [
+                    {"role": "user", "content": "Hello!"},
+                    {
+                        "role": "assistant",
+                        "content": "Hello! How can I assist you today?",
+                    },
+                    {"role": "user", "content": "What is Databricks?"},
+                    {
+                        "role": "assistant",
+                        "content": "Please continue",
+                    },
+                    {"role": "user", "content": "What is Azure?"},
+                    {
+                        "role": "assistant",
+                        "content": "I don't know anyything, do you?",
+                    },
+                    {
+                        "role": "user",
+                        "content": "Ok",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "I can't repeat sentences.",
+                    },
+                    {
+                        "role": "user",
+                        "content": "Ok",
+                    },
+                ],
+                "stream": False,
+            }
+        )
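The parametrized cases above pin down the `ensure_alternating_roles` behavior: whenever two consecutive messages share a role, a continue message of the opposite role is inserted between them, and a conversation ending on an assistant turn gets a user continue message appended. A rough sketch of that logic, checked only against the expected outputs above; the real `get_completion_messages` in `litellm.llms.prompt_templates.common_utils` may handle more edge cases:

```python
# Sketch of the alternation logic implied by the test expectations above.
DEFAULT_USER_CONTINUE = {"role": "user", "content": "Please continue."}
DEFAULT_ASSISTANT_CONTINUE = {"role": "assistant", "content": "Please continue."}


def ensure_alternating(messages, user_continue=None, assistant_continue=None):
    user_continue = user_continue or DEFAULT_USER_CONTINUE
    assistant_continue = assistant_continue or DEFAULT_ASSISTANT_CONTINUE

    fixed = []
    for msg in messages:
        if fixed and fixed[-1]["role"] == msg["role"]:
            # Two same-role messages in a row: insert the opposite-role continue message.
            filler = assistant_continue if msg["role"] == "user" else user_continue
            fixed.append(filler)
        fixed.append(msg)

    # The expected outputs end on a user turn when the input ends with an assistant turn.
    if fixed and fixed[-1]["role"] == "assistant":
        fixed.append(user_continue)
    return fixed
```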
38 tests/local_testing/test_ui_sso_helper_utils.py Normal file
@@ -0,0 +1,38 @@
+# What is this?
+## Tests the SSO helper utils (check_is_admin_only_access, has_admin_ui_access)
+
+
+import asyncio
+import os
+import random
+import sys
+import time
+import traceback
+from datetime import datetime
+
+from dotenv import load_dotenv
+from fastapi import Request
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import logging
+from litellm.proxy.management_endpoints.sso_helper_utils import (
+    check_is_admin_only_access,
+    has_admin_ui_access,
+)
+from litellm.proxy._types import LitellmUserRoles
+
+
+def test_check_is_admin_only_access():
+    assert check_is_admin_only_access("admin_only") is True
+    assert check_is_admin_only_access("user_only") is False
+
+
+def test_has_admin_ui_access():
+    assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN.value) is True
+    assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value) is True
+    assert has_admin_ui_access(LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value) is False