LiteLLM Minor Fixes & Improvements (10/17/2024) (#6293)
* fix(ui_sso.py): fix faulty admin-only check. Fixes https://github.com/BerriAI/litellm/issues/6286
* refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing, to prevent future regressions
* feat(prompt_factory): support 'ensure_alternating_roles' param. Closes https://github.com/BerriAI/litellm/issues/6257
* fix(proxy/utils.py): add DailyTagSpend to expected views
* feat(auth_utils.py): support setting a regex for clientside auth credentials. Fixes https://github.com/BerriAI/litellm/issues/6203
* build(cookbook): add tutorial for MLflow + LangChain + LiteLLM proxy tracing
* feat(argilla.py): add Argilla logging integration. Closes https://github.com/BerriAI/litellm/issues/6201
* fix: fix linting errors
* fix: fix ruff error
* test: fix test
* fix: update Vertex AI assumption - parts not always guaranteed (#6296)
* docs(configs.md): add Argilla env vars to docs
parent 5e381caf75
commit f252350881
23 changed files with 1388 additions and 43 deletions
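The new 'ensure_alternating_roles' behavior can be exercised straight from litellm.completion; a minimal sketch based on the unit tests added in this commit (the model name and messages are placeholders, and the two continue-message kwargs are optional overrides of the built-in "Please continue." defaults):

import litellm

# Two consecutive user messages - with ensure_alternating_roles=True the prompt
# factory inserts an assistant continue message between them before the call.
response = litellm.completion(
    model="databricks/databricks-meta-llama-3-1-70b-instruct",  # placeholder model
    messages=[
        {"role": "user", "content": "Hello!"},
        {"role": "user", "content": "What is Databricks?"},
    ],
    ensure_alternating_roles=True,
    user_continue_message={"role": "user", "content": "Ok"},
    assistant_continue_message={"role": "assistant", "content": "Please continue"},
)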
cookbook/mlflow_langchain_tracing_litellm_proxy.ipynb (new file, +312 lines)
@@ -0,0 +1,312 @@
[markdown cell]
# Databricks Notebook with MLFlow AutoLogging for LiteLLM Proxy calls

[code cell]
%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6

[code cell]
%pip install "langchain-openai<=0.3.1"

[code cell]
# Before logging this chain using the driver notebook, you must comment out this line.
dbutils.library.restartPython()

[code cell]
import mlflow
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_databricks import ChatDatabricks
from langchain_openai import ChatOpenAI

[code cell]
import mlflow
mlflow.langchain.autolog()

[code cell]
# These helper functions parse the `messages` array.

# Return the string contents of the most recent message from the user
def extract_user_query_string(chat_messages_array):
    return chat_messages_array[-1]["content"]


# Return the chat history, which is everything before the last question
def extract_chat_history(chat_messages_array):
    return chat_messages_array[:-1]

[code cell]
model = ChatOpenAI(
    openai_api_base="LITELLM_PROXY_BASE_URL",  # e.g.: http://0.0.0.0:4000
    model="gpt-3.5-turbo",  # LITELLM 'model_name'
    temperature=0.1,
    api_key="LITELLM_PROXY_API_KEY",  # e.g.: "sk-1234"
)

[code cell]
############
# Prompt Template for generation
############
prompt = PromptTemplate(
    template="You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    input_variables=["question"],
)

############
# FM for generation
# ChatDatabricks accepts any /llm/v1/chat model serving endpoint
############
model = ChatDatabricks(
    endpoint="databricks-dbrx-instruct",
    extra_params={"temperature": 0.01, "max_tokens": 500},
)


############
# Simple chain
############
# The framework requires the chain to return a string value.
chain = (
    {
        "question": itemgetter("messages")
        | RunnableLambda(extract_user_query_string),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | StrOutputParser()
)

[code cell]
# This is the same input your chain's REST API will accept.
question = {
    "messages": [
        {
            "role": "user",
            "content": "what is rag?",
        },
    ]
}

chain.invoke(question)

[output, execution_count 7]
'Hello there! I\'m here to help with your questions. Regarding your query about "rag," it\'s not something typically associated with a "hello world" bot, but I\'m happy to explain!\n\nRAG, or Remote Angular GUI, is a tool that allows you to create and manage Angular applications remotely. It\'s a way to develop and test Angular components and applications without needing to set up a local development environment. This can be particularly useful for teams working on distributed systems or for developers who prefer to work in a cloud-based environment.\n\nI hope this explanation of RAG has been helpful and interesting! If you have any other questions or need further clarification, feel free to ask.'

[output, display_data]
Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)

[code cell]
mlflow.models.set_model(model=model)
@@ -873,6 +873,11 @@ router_settings:
| ALLOWED_EMAIL_DOMAINS | List of email domains allowed for access
| ARIZE_API_KEY | API key for Arize platform integration
| ARIZE_SPACE_KEY | Space key for Arize platform
| ARGILLA_BATCH_SIZE | Batch size for Argilla logging
| ARGILLA_API_KEY | API key for Argilla platform
| ARGILLA_SAMPLING_RATE | Sampling rate for Argilla logging
| ARGILLA_DATASET_NAME | Dataset name for Argilla logging
| ARGILLA_BASE_URL | Base URL for Argilla service
| ATHINA_API_KEY | API key for Athina service
| AUTH_STRATEGY | Strategy used for authentication (e.g., OAuth, API key)
| AWS_ACCESS_KEY_ID | Access Key ID for AWS services
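A minimal sketch of how these Argilla variables might be set before starting LiteLLM (the values are placeholders; the names match the table above and the defaults used by the new litellm/integrations/argilla.py further below):

import os

# Placeholder values - point these at your own Argilla deployment.
os.environ["ARGILLA_API_KEY"] = "argilla.apikey"
os.environ["ARGILLA_BASE_URL"] = "http://localhost:6900"   # integration default
os.environ["ARGILLA_DATASET_NAME"] = "litellm-completion"  # integration default
os.environ["ARGILLA_BATCH_SIZE"] = "10"                    # flush after this many records
os.environ["ARGILLA_SAMPLING_RATE"] = "1"                  # log 100% of requests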
@@ -54,6 +54,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
    "langtrace",
    "gcs_bucket",
    "opik",
    "argilla",
]
_known_custom_logger_compatible_callbacks: List = list(
    get_args(_custom_logger_compatible_callbacks_literal)

@@ -61,6 +62,8 @@ _known_custom_logger_compatible_callbacks: List = list(
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
argilla_batch_size: Optional[int] = None
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Callable] = (
    []
)  # internal variable - async custom callbacks are routed here.
@@ -693,7 +693,11 @@ def completion_cost(
            completion_response, RerankResponse
        ):
            meta_obj = completion_response.meta
            billed_units = meta_obj.get("billed_units", {}) or {}
            if meta_obj is not None:
                billed_units = meta_obj.get("billed_units", {}) or {}
            else:
                billed_units = {}

            search_units = (
                billed_units.get("search_units") or 1
            )  # cohere charges per request by default.
litellm/integrations/argilla.py (new file, +410 lines)
@@ -0,0 +1,410 @@
|
|||
"""
|
||||
Send logs to Argilla for annotation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
|
||||
import dotenv # type: ignore
|
||||
import httpx
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class LangsmithInputs(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: Optional[List[Any]] = None
|
||||
stream: Optional[bool] = None
|
||||
call_type: Optional[str] = None
|
||||
litellm_call_id: Optional[str] = None
|
||||
completion_start_time: Optional[datetime] = None
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
custom_llm_provider: Optional[str] = None
|
||||
input: Optional[List[Any]] = None
|
||||
log_event_type: Optional[str] = None
|
||||
original_response: Optional[Any] = None
|
||||
response_cost: Optional[float] = None
|
||||
|
||||
# LiteLLM Virtual Key specific fields
|
||||
user_api_key: Optional[str] = None
|
||||
user_api_key_user_id: Optional[str] = None
|
||||
user_api_key_team_alias: Optional[str] = None
|
||||
|
||||
|
||||
class ArgillaItem(TypedDict):
|
||||
fields: Dict[str, Any]
|
||||
|
||||
|
||||
class ArgillaPayload(TypedDict):
|
||||
items: List[ArgillaItem]
|
||||
|
||||
|
||||
class ArgillaCredentialsObject(TypedDict):
|
||||
ARGILLA_API_KEY: str
|
||||
ARGILLA_DATASET_NAME: str
|
||||
ARGILLA_BASE_URL: str
|
||||
|
||||
|
||||
SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
non_serializable_types = (
|
||||
types.CoroutineType,
|
||||
types.FunctionType,
|
||||
types.GeneratorType,
|
||||
BaseModel,
|
||||
)
|
||||
return not isinstance(value, non_serializable_types)
|
||||
|
||||
|
||||
class ArgillaLogger(CustomBatchLogger):
|
||||
def __init__(
|
||||
self,
|
||||
argilla_api_key: Optional[str] = None,
|
||||
argilla_dataset_name: Optional[str] = None,
|
||||
argilla_base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if litellm.argilla_transformation_object is None:
|
||||
raise Exception(
|
||||
"'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
|
||||
)
|
||||
self.validate_argilla_transformation_object(
|
||||
litellm.argilla_transformation_object
|
||||
)
|
||||
self.argilla_transformation_object = litellm.argilla_transformation_object
|
||||
self.default_credentials = self.get_credentials_from_env(
|
||||
argilla_api_key=argilla_api_key,
|
||||
argilla_dataset_name=argilla_dataset_name,
|
||||
argilla_base_url=argilla_base_url,
|
||||
)
|
||||
self.sampling_rate: float = (
|
||||
float(os.getenv("ARGILLA_SAMPLING_RATE")) # type: ignore
|
||||
if os.getenv("ARGILLA_SAMPLING_RATE") is not None
|
||||
and os.getenv("ARGILLA_SAMPLING_RATE").strip().isdigit() # type: ignore
|
||||
else 1.0
|
||||
)
|
||||
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
_batch_size = (
|
||||
os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
|
||||
)
|
||||
if _batch_size:
|
||||
self.batch_size = int(_batch_size)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
|
||||
def validate_argilla_transformation_object(
|
||||
self, argilla_transformation_object: Dict[str, Any]
|
||||
):
|
||||
if not isinstance(argilla_transformation_object, dict):
|
||||
raise Exception(
|
||||
"'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
|
||||
)
|
||||
|
||||
for v in argilla_transformation_object.values():
|
||||
if v not in SUPPORTED_PAYLOAD_FIELDS:
|
||||
raise Exception(
|
||||
f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
|
||||
)
|
||||
|
||||
def get_credentials_from_env(
|
||||
self,
|
||||
argilla_api_key: Optional[str],
|
||||
argilla_dataset_name: Optional[str],
|
||||
argilla_base_url: Optional[str],
|
||||
) -> ArgillaCredentialsObject:
|
||||
|
||||
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
|
||||
if _credentials_api_key is None:
|
||||
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
|
||||
|
||||
_credentials_base_url = (
|
||||
argilla_base_url
|
||||
or os.getenv("ARGILLA_BASE_URL")
|
||||
or "http://localhost:6900/"
|
||||
)
|
||||
if _credentials_base_url is None:
|
||||
raise Exception(
|
||||
"Invalid Argilla Base URL given. _credentials_base_url=None."
|
||||
)
|
||||
|
||||
_credentials_dataset_name = (
|
||||
argilla_dataset_name
|
||||
or os.getenv("ARGILLA_DATASET_NAME")
|
||||
or "litellm-completion"
|
||||
)
|
||||
if _credentials_dataset_name is None:
|
||||
raise Exception("Invalid Argilla Dataset give. Value=None.")
|
||||
else:
|
||||
dataset_response = litellm.module_level_client.get(
|
||||
url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
|
||||
headers={"X-Argilla-Api-Key": _credentials_api_key},
|
||||
)
|
||||
json_response = dataset_response.json()
|
||||
if (
|
||||
"items" in json_response
|
||||
and isinstance(json_response["items"], list)
|
||||
and len(json_response["items"]) > 0
|
||||
):
|
||||
_credentials_dataset_name = json_response["items"][0]["id"]
|
||||
|
||||
return ArgillaCredentialsObject(
|
||||
ARGILLA_API_KEY=_credentials_api_key,
|
||||
ARGILLA_BASE_URL=_credentials_base_url,
|
||||
ARGILLA_DATASET_NAME=_credentials_dataset_name,
|
||||
)
|
||||
|
||||
def get_chat_messages(
|
||||
self, payload: StandardLoggingPayload
|
||||
) -> List[Dict[str, Any]]:
|
||||
payload_messages = payload.get("messages", None)
|
||||
|
||||
if payload_messages is None:
|
||||
raise Exception("No chat messages found in payload.")
|
||||
|
||||
if (
|
||||
isinstance(payload_messages, list)
|
||||
and len(payload_messages) > 0
|
||||
and isinstance(payload_messages[0], dict)
|
||||
):
|
||||
return payload_messages
|
||||
elif isinstance(payload_messages, dict):
|
||||
return [payload_messages]
|
||||
else:
|
||||
raise Exception(f"Invalid chat messages format: {payload_messages}")
|
||||
|
||||
def get_str_response(self, payload: StandardLoggingPayload) -> str:
|
||||
response = payload["response"]
|
||||
|
||||
if response is None:
|
||||
raise Exception("No response found in payload.")
|
||||
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
elif isinstance(response, dict):
|
||||
return (
|
||||
response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Invalid response format: {response}")
|
||||
|
||||
def _prepare_log_data(
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
) -> ArgillaItem:
|
||||
try:
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if payload is None:
|
||||
raise Exception("Error logging request payload. Payload=none.")
|
||||
|
||||
argilla_message = self.get_chat_messages(payload)
|
||||
argilla_response = self.get_str_response(payload)
|
||||
argilla_item: ArgillaItem = {"fields": {}}
|
||||
for k, v in self.argilla_transformation_object.items():
|
||||
if v == "messages":
|
||||
argilla_item["fields"][k] = argilla_message
|
||||
elif v == "response":
|
||||
argilla_item["fields"][k] = argilla_response
|
||||
else:
|
||||
argilla_item["fields"][k] = payload.get(v, None)
|
||||
return argilla_item
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _send_batch(self):
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
|
||||
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
|
||||
|
||||
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
|
||||
|
||||
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
|
||||
|
||||
headers = {"X-Argilla-Api-Key": argilla_api_key}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=url,
|
||||
json=self.log_queue,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Argilla Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Batch of {len(self.log_queue)} runs successfully created"
|
||||
)
|
||||
|
||||
self.log_queue.clear()
|
||||
except Exception:
|
||||
verbose_logger.exception("Argilla Layer Error - Error sending batch.")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = (
|
||||
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
|
||||
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
|
||||
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
|
||||
else 1.0
|
||||
)
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
)
|
||||
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
self._send_batch()
|
||||
|
||||
except Exception:
|
||||
verbose_logger.exception("Langsmith Layer Error - log_success_event error")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.debug(
|
||||
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
|
||||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Argilla Layer Error - error logging async success event."
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
sampling_rate = self.sampling_rate
|
||||
random_sample = random.random()
|
||||
if random_sample > sampling_rate:
|
||||
verbose_logger.info(
|
||||
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
|
||||
sampling_rate, random_sample
|
||||
)
|
||||
)
|
||||
return # Skip logging
|
||||
verbose_logger.info("Langsmith Failure Event Logging!")
|
||||
try:
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
"Langsmith Layer Error - error logging async failure event."
|
||||
)
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
sends runs to /batch endpoint
|
||||
|
||||
Sends runs from self.log_queue
|
||||
|
||||
Returns: None
|
||||
|
||||
Raises: Does not raise an exception, will only verbose_logger.exception()
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
|
||||
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
|
||||
|
||||
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
|
||||
|
||||
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
|
||||
|
||||
headers = {"X-Argilla-Api-Key": argilla_api_key}
|
||||
|
||||
try:
|
||||
response = await self.async_httpx_client.put(
|
||||
url=url,
|
||||
data=json.dumps(
|
||||
{
|
||||
"items": self.log_queue,
|
||||
}
|
||||
),
|
||||
headers=headers,
|
||||
timeout=60000,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Argilla Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
"Batch of %s runs successfully created", len(self.log_queue)
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
verbose_logger.exception("Argilla HTTP Error")
|
||||
except Exception:
|
||||
verbose_logger.exception("Argilla Layer Error")
|
|
@@ -59,6 +59,7 @@ from litellm.utils import (
)

from ..integrations.aispend import AISpendLogger
from ..integrations.argilla import ArgillaLogger
from ..integrations.athina import AthinaLogger
from ..integrations.berrispend import BerriSpendLogger
from ..integrations.braintrust_logging import BraintrustLogger
@@ -2339,6 +2340,14 @@ def _init_custom_logger_compatible_class(
        _langsmith_logger = LangsmithLogger()
        _in_memory_loggers.append(_langsmith_logger)
        return _langsmith_logger  # type: ignore
    elif logging_integration == "argilla":
        for callback in _in_memory_loggers:
            if isinstance(callback, ArgillaLogger):
                return callback  # type: ignore

        _argilla_logger = ArgillaLogger()
        _in_memory_loggers.append(_argilla_logger)
        return _argilla_logger  # type: ignore
    elif logging_integration == "literalai":
        for callback in _in_memory_loggers:
            if isinstance(callback, LiteralAILogger):
@@ -2521,6 +2530,10 @@ def get_custom_logger_compatible_class(
        for callback in _in_memory_loggers:
            if isinstance(callback, LangsmithLogger):
                return callback
    elif logging_integration == "argilla":
        for callback in _in_memory_loggers:
            if isinstance(callback, ArgillaLogger):
                return callback
    elif logging_integration == "literalai":
        for callback in _in_memory_loggers:
            if isinstance(callback, LiteralAILogger):
@@ -163,10 +163,11 @@ class AsyncHTTPHandler:
        try:
            if timeout is None:
                timeout = self.timeout

            req = self.client.build_request(
                "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
            )
            response = await self.client.send(req, stream=stream)
            response = await self.client.send(req)
            response.raise_for_status()
            return response
        except (httpx.RemoteProtocolError, httpx.ConnectError):
@@ -3,11 +3,25 @@ Common utility functions used for translating messages across providers
"""

import json
from typing import Dict, List
from copy import deepcopy
from typing import Dict, List, Literal, Optional

from litellm.types.llms.openai import AllMessageValues
import litellm
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionAssistantMessage,
    ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices

DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
    content="Please continue.", role="user"
)

DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
    content="Please continue.", role="assistant"
)


def convert_content_list_to_str(message: AllMessageValues) -> str:
    """
@ -69,3 +83,145 @@ def get_content_from_model_response(response: ModelResponse) -> str:
|
|||
elif isinstance(choice, StreamingChoices):
|
||||
content += getattr(choice, "delta", {}).get("content", "") or ""
|
||||
return content
|
||||
|
||||
|
||||
def detect_first_expected_role(
|
||||
messages: List[AllMessageValues],
|
||||
) -> Optional[Literal["user", "assistant"]]:
|
||||
"""
|
||||
Detect the first expected role based on the message sequence.
|
||||
|
||||
Rules:
|
||||
1. If messages list is empty, assume 'user' starts
|
||||
2. If first message is from assistant, expect 'user' next
|
||||
3. If first message is from user, expect 'assistant' next
|
||||
4. If first message is system, look at the next non-system message
|
||||
|
||||
Returns:
|
||||
str: Either 'user' or 'assistant'
|
||||
None: If no 'user' or 'assistant' messages provided
|
||||
"""
|
||||
if not messages:
|
||||
return "user"
|
||||
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
continue
|
||||
return "user" if message["role"] == "assistant" else "assistant"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _insert_user_continue_message(
|
||||
messages: List[AllMessageValues],
|
||||
user_continue_message: Optional[ChatCompletionUserMessage],
|
||||
ensure_alternating_roles: bool,
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Inserts a user continue message into the messages list.
|
||||
Handles three cases:
|
||||
1. Initial assistant message
|
||||
2. Final assistant message
|
||||
3. Consecutive assistant messages
|
||||
|
||||
Only inserts messages between consecutive assistant messages,
|
||||
ignoring all other role types.
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
result_messages = messages.copy() # Don't modify the input list
|
||||
continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE
|
||||
|
||||
# Handle first message if it's an assistant message
|
||||
if result_messages[0]["role"] == "assistant":
|
||||
result_messages.insert(0, continue_message)
|
||||
|
||||
# Handle consecutive assistant messages and final message
|
||||
i = 1 # Start from second message since we handled first message
|
||||
while i < len(result_messages):
|
||||
curr_message = result_messages[i]
|
||||
prev_message = result_messages[i - 1]
|
||||
|
||||
# Only check for consecutive assistant messages
|
||||
# Ignore all other role types
|
||||
if curr_message["role"] == "assistant" and prev_message["role"] == "assistant":
|
||||
result_messages.insert(i, continue_message)
|
||||
i += 2 # Skip over the message we just inserted
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Handle final message
|
||||
if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles:
|
||||
result_messages.append(continue_message)
|
||||
|
||||
return result_messages
|
||||
|
||||
|
||||
def _insert_assistant_continue_message(
|
||||
messages: List[AllMessageValues],
|
||||
assistant_continue_message: Optional[ChatCompletionAssistantMessage] = None,
|
||||
ensure_alternating_roles: bool = True,
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Add assistant continuation messages between consecutive user messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
assistant_continue_message: Optional custom assistant message
|
||||
ensure_alternating_roles: Whether to enforce alternating roles
|
||||
|
||||
Returns:
|
||||
Modified list of messages with inserted assistant messages
|
||||
"""
|
||||
if not ensure_alternating_roles or len(messages) <= 1:
|
||||
return messages
|
||||
|
||||
# Create a new list to store modified messages
|
||||
modified_messages: List[AllMessageValues] = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
modified_messages.append(message)
|
||||
|
||||
# Check if we need to insert an assistant message
|
||||
if (
|
||||
i < len(messages) - 1 # Not the last message
|
||||
and message.get("role") == "user" # Current is user
|
||||
and messages[i + 1].get("role") == "user"
|
||||
): # Next is user
|
||||
|
||||
# Insert assistant message
|
||||
continue_message = (
|
||||
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
|
||||
)
|
||||
modified_messages.append(continue_message)
|
||||
|
||||
return modified_messages
|
||||
|
||||
|
||||
def get_completion_messages(
|
||||
messages: List[AllMessageValues],
|
||||
assistant_continue_message: Optional[ChatCompletionAssistantMessage],
|
||||
user_continue_message: Optional[ChatCompletionUserMessage],
|
||||
ensure_alternating_roles: bool,
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Ensures messages alternate between user and assistant roles by adding placeholders
|
||||
only when there are consecutive messages of the same role.
|
||||
|
||||
1. ensure 'user' message before 1st 'assistant' message
|
||||
2. ensure 'user' message after last 'assistant' message
|
||||
"""
|
||||
if not ensure_alternating_roles:
|
||||
return messages.copy()
|
||||
|
||||
## INSERT USER CONTINUE MESSAGE
|
||||
messages = _insert_user_continue_message(
|
||||
messages, user_continue_message, ensure_alternating_roles
|
||||
)
|
||||
|
||||
## INSERT ASSISTANT CONTINUE MESSAGE
|
||||
messages = _insert_assistant_continue_message(
|
||||
messages, assistant_continue_message, ensure_alternating_roles
|
||||
)
|
||||
return messages
|
||||
|
|
|
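A minimal call sketch of the new get_completion_messages helper shown above; the expected result mirrors the unit tests added later in this commit (message contents are placeholders):

from litellm.llms.prompt_templates.common_utils import get_completion_messages

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "user", "content": "What is Databricks?"},
]

# Passing None for the continue messages falls back to the
# DEFAULT_ASSISTANT_CONTINUE_MESSAGE / DEFAULT_USER_CONTINUE_MESSAGE constants.
fixed = get_completion_messages(
    messages=messages,
    assistant_continue_message=None,
    user_continue_message=None,
    ensure_alternating_roles=True,
)
# fixed == [
#     {"role": "user", "content": "Hello!"},
#     {"role": "assistant", "content": "Please continue."},
#     {"role": "user", "content": "What is Databricks?"},
# ]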
@ -1388,9 +1388,14 @@ def anthropic_messages_pt(
|
|||
for m in user_message_types_block["content"]:
|
||||
if m.get("type", "") == "image_url":
|
||||
m = cast(ChatCompletionImageObject, m)
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
openai_image_url=m["image_url"]["url"] # type: ignore
|
||||
)
|
||||
if isinstance(m["image_url"], str):
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
openai_image_url=m["image_url"]
|
||||
)
|
||||
else:
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
openai_image_url=m["image_url"]["url"]
|
||||
)
|
||||
|
||||
_anthropic_content_element = AnthropicMessagesImageParam(
|
||||
type="image",
|
||||
|
|
|
@ -887,10 +887,16 @@ class VertexLLM(VertexBase):
|
|||
|
||||
if "citationMetadata" in candidate:
|
||||
citation_metadata.append(candidate["citationMetadata"])
|
||||
if "text" in candidate["content"]["parts"][0]:
|
||||
if (
|
||||
"parts" in candidate["content"]
|
||||
and "text" in candidate["content"]["parts"][0]
|
||||
):
|
||||
content_str = candidate["content"]["parts"][0]["text"]
|
||||
|
||||
if "functionCall" in candidate["content"]["parts"][0]:
|
||||
if (
|
||||
"parts" in candidate["content"]
|
||||
and "functionCall" in candidate["content"]["parts"][0]
|
||||
):
|
||||
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=candidate["content"]["parts"][0]["functionCall"][
|
||||
"name"
|
||||
|
@ -1358,7 +1364,11 @@ class ModelResponseIterator:
|
|||
if _candidates and len(_candidates) > 0:
|
||||
gemini_chunk = _candidates[0]
|
||||
|
||||
if gemini_chunk and "content" in gemini_chunk:
|
||||
if (
|
||||
gemini_chunk
|
||||
and "content" in gemini_chunk
|
||||
and "parts" in gemini_chunk["content"]
|
||||
):
|
||||
if "text" in gemini_chunk["content"]["parts"][0]:
|
||||
text = gemini_chunk["content"]["parts"][0]["text"]
|
||||
elif "functionCall" in gemini_chunk["content"]["parts"][0]:
|
||||
|
|
|
@@ -109,6 +109,7 @@ from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.common_utils import get_completion_messages
from .llms.prompt_templates.factory import (
    custom_prompt,
    function_call_prompt,

@@ -144,7 +145,11 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
    VertexEmbedding,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.llms.openai import (
    ChatCompletionAssistantMessage,
    ChatCompletionUserMessage,
    HttpxBinaryResponseContent,
)
from .types.utils import (
    AdapterCompletionStreamWrapper,
    ChatCompletionMessageToolCall,

@@ -748,6 +753,15 @@ def completion(  # type: ignore
    proxy_server_request = kwargs.get("proxy_server_request", None)
    fallbacks = kwargs.get("fallbacks", None)
    headers = kwargs.get("headers", None) or extra_headers
    ensure_alternating_roles: Optional[bool] = kwargs.get(
        "ensure_alternating_roles", None
    )
    user_continue_message: Optional[ChatCompletionUserMessage] = kwargs.get(
        "user_continue_message", None
    )
    assistant_continue_message: Optional[ChatCompletionAssistantMessage] = kwargs.get(
        "assistant_continue_message", None
    )
    if headers is None:
        headers = {}

@@ -784,7 +798,12 @@ def completion(  # type: ignore
    ### Admin Controls ###
    no_log = kwargs.get("no-log", False)
    ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
    messages = deepcopy(messages)
    messages = get_completion_messages(
        messages=messages,
        ensure_alternating_roles=ensure_alternating_roles or False,
        user_continue_message=user_continue_message,
        assistant_continue_message=assistant_continue_message,
    )
    ######## end of unpacking kwargs ###########
    openai_params = [
        "functions",
@@ -5,8 +5,8 @@ model_list:
      api_key: os.environ/OPENAI_API_KEY


assistant_settings:
  custom_llm_provider: azure
  litellm_params:
    api_key: os.environ/AZURE_API_KEY
    api_base: os.environ/AZURE_API_BASE
litellm_settings:
  callbacks: ["argilla"]
  argilla_transformation_object:
    user_input: "messages"
    llm_output: "response"
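Equivalent SDK-side wiring for the proxy config above, as a rough sketch (only "messages" and "response" are supported payload fields per SUPPORTED_PAYLOAD_FIELDS in the new integration; the field names on the left are whatever your Argilla dataset uses):

import litellm

# Map your Argilla dataset's field names to the supported payload fields.
litellm.argilla_transformation_object = {
    "user_input": "messages",   # Argilla field <- request messages
    "llm_output": "response",   # Argilla field <- model response text
}

# Register the Argilla logger; it reads ARGILLA_API_KEY / ARGILLA_BASE_URL /
# ARGILLA_DATASET_NAME from the environment.
litellm.callbacks = ["argilla"]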
@ -1,13 +1,17 @@
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm import Router, provider_list
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.types.router import (
|
||||
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
|
||||
ConfigurableClientsideParamsCustomAuth,
|
||||
)
|
||||
|
||||
|
||||
def _get_request_ip_address(
|
||||
|
@ -73,8 +77,41 @@ def check_complete_credentials(request_body: dict) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
|
||||
"""
|
||||
Check if request_body_value matches the regex_str or is equal to param
|
||||
"""
|
||||
if re.match(regex_str, request_body_value) or regex_str == request_body_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_param_allowed(
|
||||
param: str,
|
||||
request_body_value: Any,
|
||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if param is a str or dict and if request_body_value is in the list of allowed values
|
||||
"""
|
||||
if configurable_clientside_auth_params is None:
|
||||
return False
|
||||
|
||||
for item in configurable_clientside_auth_params:
|
||||
if isinstance(item, str) and param == item:
|
||||
return True
|
||||
elif isinstance(item, Dict):
|
||||
if param == "api_base" and check_regex_or_str_match(
|
||||
request_body_value=request_body_value,
|
||||
regex_str=item["api_base"],
|
||||
): # assume param is a regex
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _allow_model_level_clientside_configurable_parameters(
|
||||
model: str, param: str, llm_router: Optional[Router]
|
||||
model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if model is allowed to use configurable client-side params
|
||||
|
@ -99,10 +136,11 @@ def _allow_model_level_clientside_configurable_parameters(
|
|||
if model_info is None or model_info.configurable_clientside_auth_params is None:
|
||||
return False
|
||||
|
||||
if param in model_info.configurable_clientside_auth_params:
|
||||
return True
|
||||
|
||||
return False
|
||||
return _is_param_allowed(
|
||||
param=param,
|
||||
request_body_value=request_body_value,
|
||||
configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
|
||||
)
|
||||
|
||||
|
||||
def is_request_body_safe(
|
||||
|
@ -127,7 +165,10 @@ def is_request_body_safe(
|
|||
return True
|
||||
elif (
|
||||
_allow_model_level_clientside_configurable_parameters(
|
||||
model=model, param=param, llm_router=llm_router
|
||||
model=model,
|
||||
param=param,
|
||||
request_body_value=request_body[param],
|
||||
llm_router=llm_router,
|
||||
)
|
||||
is True
|
||||
):
|
||||
|
|
litellm/proxy/management_endpoints/sso_helper_utils.py (new file, +24 lines)
@@ -0,0 +1,24 @@
from fastapi import HTTPException

from litellm.proxy._types import LitellmUserRoles


def check_is_admin_only_access(ui_access_mode: str) -> bool:
    """Checks ui access mode is admin_only"""
    return ui_access_mode == "admin_only"


def has_admin_ui_access(user_role: str) -> bool:
    """
    Check if the user has admin access to the UI.

    Returns:
        bool: True if user is 'proxy_admin' or 'proxy_admin_view_only', False otherwise.
    """

    if (
        user_role != LitellmUserRoles.PROXY_ADMIN.value
        and user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
    ):
        return False
    return True
@@ -31,6 +31,10 @@ from litellm.proxy.common_utils.admin_ui_utils import (
    show_missing_vars_in_env,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.sso_helper_utils import (
    check_is_admin_only_access,
    has_admin_ui_access,
)
from litellm.secret_managers.main import str_to_bool

if TYPE_CHECKING:

@@ -545,17 +549,16 @@ async def auth_callback(request: Request):
        f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
    )
    ## CHECK IF ROLE ALLOWED TO USE PROXY ##
    if ui_access_mode == "admin_only" and (
        user_role != LitellmUserRoles.PROXY_ADMIN.value
        or user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
    ):
        verbose_proxy_logger.debug("EXCEPTION RAISED")
        raise HTTPException(
            status_code=401,
            detail={
                "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
            },
        )
    is_admin_only_access = check_is_admin_only_access(ui_access_mode)
    if is_admin_only_access:
        has_access = has_admin_ui_access(user_role)
        if not has_access:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
                },
            )

    import jwt
@@ -1124,6 +1124,7 @@ class PrismaClient:
            "MonthlyGlobalSpendPerKey",
            "MonthlyGlobalSpendPerUserPerKey",
            "Last30dTopEndUsersSpend",
            "DailyTagSpend",
        ]
        required_view = "LiteLLM_VerificationTokenView"
        expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
@@ -104,6 +104,7 @@ from litellm.types.llms.openai import (
    Thread,
)
from litellm.types.router import (
    CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
    SPECIAL_MODEL_INFO_PARAMS,
    VALID_LITELLM_ENVIRONMENTS,
    AlertingConfig,

@@ -4183,7 +4184,7 @@ class Router:

        total_tpm: Optional[int] = None
        total_rpm: Optional[int] = None
        configurable_clientside_auth_params: Optional[List[str]] = None
        configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None

        for model in self.model_list:
            is_match = False
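The regex support for clientside auth credentials is configured per model deployment; a minimal sketch taken almost verbatim from the new tests/local_testing/test_auth_utils.py below (the model name and key are dummies):

from litellm.router import Router

router = Router(
    model_list=[
        {
            "model_name": "dummy-model",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "dummy-key",
                # Clients may only override api_base when it matches this regex.
                "configurable_clientside_auth_params": [
                    {"api_base": "^https://litellm.*direct\\.fireworks\\.ai/v1$"}
                ],
            },
        }
    ]
)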
@@ -65,7 +65,7 @@ class HttpxPartType(TypedDict, total=False):

class HttpxContentType(TypedDict, total=False):
    role: Literal["user", "model"]
    parts: Required[List[HttpxPartType]]
    parts: List[HttpxPartType]


class ContentType(TypedDict, total=False):
@ -5,10 +5,11 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
|
|||
import datetime
|
||||
import enum
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ..exceptions import RateLimitError
|
||||
from .completion import CompletionRequest
|
||||
|
@ -16,6 +17,15 @@ from .embedding import EmbeddingRequest
|
|||
from .utils import ModelResponse
|
||||
|
||||
|
||||
class ConfigurableClientsideParamsCustomAuth(TypedDict):
|
||||
api_base: str
|
||||
|
||||
|
||||
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = Optional[
|
||||
List[Union[str, ConfigurableClientsideParamsCustomAuth]]
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: Union[CompletionRequest, EmbeddingRequest]
|
||||
|
@ -139,7 +149,7 @@ class GenericLiteLLMParams(BaseModel):
|
|||
)
|
||||
max_retries: Optional[int] = None
|
||||
organization: Optional[str] = None # for openai orgs
|
||||
configurable_clientside_auth_params: Optional[List[str]] = None
|
||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
||||
## UNIFIED PROJECT/REGION ##
|
||||
region_name: Optional[str] = None
|
||||
## VERTEX AI ##
|
||||
|
@ -311,9 +321,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
stream_timeout: Optional[Union[float, str]]
|
||||
max_retries: Optional[int]
|
||||
organization: Optional[Union[List, str]] # for openai orgs
|
||||
configurable_clientside_auth_params: Optional[
|
||||
List[str]
|
||||
] # for allowing api base switching on finetuned models
|
||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS # for allowing api base switching on finetuned models
|
||||
## DROP PARAMS ##
|
||||
drop_params: Optional[bool]
|
||||
## UNIFIED PROJECT/REGION ##
|
||||
|
@ -496,7 +504,7 @@ class ModelGroupInfo(BaseModel):
|
|||
supports_vision: bool = Field(default=False)
|
||||
supports_function_calling: bool = Field(default=False)
|
||||
supported_openai_params: Optional[List[str]] = Field(default=[])
|
||||
configurable_clientside_auth_params: Optional[List[str]] = None
|
||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
||||
|
||||
|
||||
class AssistantsTypedDict(TypedDict):
|
||||
|
|
|
@@ -1246,6 +1246,9 @@ all_litellm_params = [
    "user_continue_message",
    "configurable_clientside_auth_params",
    "weight",
    "ensure_alternating_roles",
    "assistant_continue_message",
    "user_continue_message",
]
tests/local_testing/test_auth_utils.py (new file, +70 lines)
@@ -0,0 +1,70 @@
# What is this?
## Tests if proxy/auth/auth_utils.py works as expected

import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.auth.auth_utils import (
    _allow_model_level_clientside_configurable_parameters,
)
from litellm.router import Router


@pytest.mark.parametrize(
    "allowed_param, input_value, should_return_true",
    [
        ("api_base", {"api_base": "http://dummy.com"}, True),
        (
            {"api_base": "https://api.openai.com/v1"},
            {"api_base": "https://api.openai.com/v1"},
            True,
        ),  # should return True
        (
            {"api_base": "https://api.openai.com/v1"},
            {"api_base": "https://api.anthropic.com/v1"},
            False,
        ),  # should return False
        (
            {"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"},
            {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
            True,
        ),
        (
            {"api_base": "^https://litellm.*novice\.fireworks\.ai/v1$"},
            {"api_base": "https://litellm-dev.direct.fireworks.ai/v1"},
            False,
        ),
    ],
)
def test_configurable_clientside_parameters(
    allowed_param, input_value, should_return_true
):
    router = Router(
        model_list=[
            {
                "model_name": "dummy-model",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "dummy-key",
                    "configurable_clientside_auth_params": [allowed_param],
                },
            }
        ]
    )
    resp = _allow_model_level_clientside_configurable_parameters(
        model="dummy-model",
        param="api_base",
        request_body_value=input_value["api_base"],
        llm_router=router,
    )
    print(resp)
    assert resp == should_return_true
@ -22,9 +22,13 @@ from litellm.llms.prompt_templates.factory import (
|
|||
llama_2_chat_pt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.prompt_templates.common_utils import (
|
||||
get_completion_messages,
|
||||
)
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
|
||||
_gemini_convert_messages_with_history,
|
||||
)
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def test_llama_3_prompt():
|
||||
|
@ -457,3 +461,217 @@ def test_azure_tool_call_invoke_helper():
|
|||
"function_call": {"name": "get_weather", "arguments": ""},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, expected_messages, user_continue_message, assistant_continue_message",
|
||||
[
|
||||
(
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{"role": "assistant", "content": "I don't know anyything, do you?"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Please continue.",
|
||||
},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I don't know anyything, do you?",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please continue.",
|
||||
},
|
||||
],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Please continue."},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Please continue."},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Please continue.",
|
||||
},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{"role": "assistant", "content": "I don't know anyything, do you?"},
|
||||
{"role": "assistant", "content": "I can't repeat sentences."},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Please continue",
|
||||
},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I don't know anyything, do you?",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Ok",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I can't repeat sentences.",
|
||||
},
|
||||
{"role": "user", "content": "Ok"},
|
||||
],
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Ok",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Please continue",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_ensure_alternating_roles(
|
||||
messages, expected_messages, user_continue_message, assistant_continue_message
|
||||
):
|
||||
|
||||
messages = get_completion_messages(
|
||||
messages=messages,
|
||||
assistant_continue_message=assistant_continue_message,
|
||||
user_continue_message=user_continue_message,
|
||||
ensure_alternating_roles=True,
|
||||
)
|
||||
|
||||
print(messages)
|
||||
|
||||
assert messages == expected_messages
|
||||
|
||||
|
||||
def test_alternating_roles_e2e():
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
import json
|
||||
|
||||
litellm.set_verbose = True
|
||||
http_handler = HTTPHandler()
|
||||
|
||||
with patch.object(http_handler, "post", new=MagicMock()) as mock_post:
|
||||
response = litellm.completion(
|
||||
**{
|
||||
"model": "databricks/databricks-meta-llama-3-1-70b-instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{"role": "assistant", "content": "I don't know anyything, do you?"},
|
||||
{"role": "assistant", "content": "I can't repeat sentences."},
|
||||
],
|
||||
"user_continue_message": {
|
||||
"role": "user",
|
||||
"content": "Ok",
|
||||
},
|
||||
"assistant_continue_message": {
|
||||
"role": "assistant",
|
||||
"content": "Please continue",
|
||||
},
|
||||
"ensure_alternating_roles": True,
|
||||
},
|
||||
client=http_handler,
|
||||
)
|
||||
print(f"response: {response}")
|
||||
assert mock_post.call_args.kwargs["data"] == json.dumps(
|
||||
{
|
||||
"model": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
},
|
||||
{"role": "user", "content": "What is Databricks?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Please continue",
|
||||
},
|
||||
{"role": "user", "content": "What is Azure?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I don't know anyything, do you?",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Ok",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I can't repeat sentences.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Ok",
|
||||
},
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
)
|
||||
|
|
tests/local_testing/test_ui_sso_helper_utils.py (new file, +38 lines)
@@ -0,0 +1,38 @@
# What is this?
## This tests the batch update spend logic on the proxy server


import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request

load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import logging
from litellm.proxy.management_endpoints.sso_helper_utils import (
    check_is_admin_only_access,
    has_admin_ui_access,
)
from litellm.proxy._types import LitellmUserRoles


def test_check_is_admin_only_access():
    assert check_is_admin_only_access("admin_only") is True
    assert check_is_admin_only_access("user_only") is False


def test_has_admin_ui_access():
    assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN.value) is True
    assert has_admin_ui_access(LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value) is True
    assert has_admin_ui_access(LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value) is False