diff --git a/docker-compose.yml b/docker-compose.yml
index ca98ec784..be84462ef 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -15,7 +15,7 @@ services:
     ports:
       - "4000:4000" # Map the container port to the host, change the host port if necessary
     environment:
-      DATABASE_URL: "postgresql://postgres:example@db:5432/postgres"
+      DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm"
       STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
     env_file:
       - .env # Load local .env file
@@ -25,11 +25,13 @@ services:
     image: postgres
     restart: always
     environment:
-      POSTGRES_PASSWORD: example
+      POSTGRES_DB: litellm
+      POSTGRES_USER: llmproxy
+      POSTGRES_PASSWORD: dbpassword9090
     healthcheck:
-      test: ["CMD-SHELL", "pg_isready"]
+      test: ["CMD-SHELL", "pg_isready -d litellm -U llmproxy"]
       interval: 1s
       timeout: 5s
       retries: 10
-# ...rest of your docker-compose config if any
\ No newline at end of file
+# ...rest of your docker-compose config if any
diff --git a/docs/my-website/docs/observability/arize_integration.md b/docs/my-website/docs/observability/arize_integration.md
new file mode 100644
index 000000000..d2592da6a
--- /dev/null
+++ b/docs/my-website/docs/observability/arize_integration.md
@@ -0,0 +1,72 @@
+import Image from '@theme/IdealImage';
+
+# 🔥 Arize AI - Logging LLM Input/Output
+
+AI Observability and Evaluation Platform
+
+:::tip
+
+This integration is community maintained. Please open an issue at
+https://github.com/BerriAI/litellm if you run into a bug.
+
+:::
+
+
+
+## Pre-Requisites
+Create an account on [Arize AI](https://app.arize.com/auth/login)
+
+## Quick Start
+Use just 2 lines of code to instantly log your responses **across all providers** with Arize
+
+
+```python
+litellm.callbacks = ["arize"]
+```
+```python
+import litellm
+import os
+
+os.environ["ARIZE_SPACE_KEY"] = ""
+os.environ["ARIZE_API_KEY"] = ""
+
+# LLM API Keys
+os.environ['OPENAI_API_KEY']=""
+
+# set arize as a callback; litellm will send the data to arize
+litellm.callbacks = ["arize"]
+
+# openai call
+response = litellm.completion(
+  model="gpt-3.5-turbo",
+  messages=[
+    {"role": "user", "content": "Hi 👋 - i'm openai"}
+  ]
+)
+```
+
+### Using with LiteLLM Proxy
+
+
+```yaml
+model_list:
+  - model_name: gpt-4
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+
+litellm_settings:
+  callbacks: ["arize"]
+
+environment_variables:
+  ARIZE_SPACE_KEY: "d0*****"
+  ARIZE_API_KEY: "141a****"
+```
+
+## Support & Talk to Founders
+
+- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
+- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
+- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
+- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
diff --git a/docs/my-website/docs/observability/langsmith_integration.md b/docs/my-website/docs/observability/langsmith_integration.md
index 79d047e33..d57a64f09 100644
--- a/docs/my-website/docs/observability/langsmith_integration.md
+++ b/docs/my-website/docs/observability/langsmith_integration.md
@@ -1,6 +1,6 @@
 import Image from '@theme/IdealImage';
 
-# Langsmith - Logging LLM Input/Output
+# 🦜 Langsmith - Logging LLM Input/Output
 
 :::tip
 
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index 0d5016645..34e153750 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -48,6 +48,20 @@ A number of these headers could be useful for 
troubleshooting, but the `x-litellm-call-id` is the one that is most useful for tracking a request across components in your system, including in logging tools. +## Redacting UserAPIKeyInfo + +Redact information about the user api key (hashed token, user_id, team id, etc.), from logs. + +Currently supported for Langfuse, OpenTelemetry, Logfire, ArizeAI logging. + +```yaml +litellm_settings: + callbacks: ["langfuse"] + redact_user_api_key_info: true +``` + +Removes any field with `user_api_key_*` from metadata. + ## Logging Proxy Input/Output - Langfuse We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this will log all successfull LLM calls to langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment @@ -202,6 +216,9 @@ print(response) ### Team based Logging to Langfuse +[👉 Tutorial - Allow each team to use their own Langfuse Project / custom callbacks](team_logging) + ### Redacting Messages, Response Content from Langfuse Logging @@ -1106,6 +1123,52 @@ environment_variables: ``` +2. Start Proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl --location 'http://0.0.0.0:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--data ' { + "model": "fake-openai-endpoint", + "messages": [ + { + "role": "user", + "content": "Hello, Claude gm!" + } + ], + } +' +``` +Expect to see your log on Langfuse + + + +## Logging LLM IO to Arize AI + +1. Set `success_callback: ["arize"]` on litellm config.yaml + +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +litellm_settings: + callbacks: ["arize"] + +environment_variables: + ARIZE_SPACE_KEY: "d0*****" + ARIZE_API_KEY: "141a****" +``` + 2. Start Proxy ``` diff --git a/docs/my-website/docs/proxy/team_based_routing.md b/docs/my-website/docs/proxy/team_based_routing.md index 6a68e5a1f..6254abaf5 100644 --- a/docs/my-website/docs/proxy/team_based_routing.md +++ b/docs/my-website/docs/proxy/team_based_routing.md @@ -71,7 +71,13 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \ }' ``` +## Team Based Logging +[👉 Tutorial - Allow each team to use their own Langfuse Project / custom callbacks](team_logging.md) + + + + diff --git a/docs/my-website/docs/proxy/team_logging.md b/docs/my-website/docs/proxy/team_logging.md new file mode 100644 index 000000000..c3758a7c7 --- /dev/null +++ b/docs/my-website/docs/proxy/team_logging.md @@ -0,0 +1,84 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# 👥📊 Team Based Logging + +Allow each team to use their own Langfuse Project / custom callbacks + +**This allows you to do the following** +``` +Team 1 -> Logs to Langfuse Project 1 +Team 2 -> Logs to Langfuse Project 2 +Team 3 -> Logs to Langsmith +``` + +## Quick Start + +## 1. 
Set callback for team
+
+```shell
+curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "callback_name": "langfuse",
+  "callback_type": "success",
+  "callback_vars": {
+    "langfuse_public_key": "pk",
+    "langfuse_secret_key": "sk_",
+    "langfuse_host": "https://cloud.langfuse.com"
+  }
+
+}'
+```
+
+#### Supported Values
+
+| Field | Supported Values | Notes |
+|-------|------------------|-------|
+| `callback_name` | `"langfuse"` | Currently only supports "langfuse" |
+| `callback_type` | `"success"`, `"failure"`, `"success_and_failure"` | |
+| `callback_vars` | | dict of callback settings |
+|     `langfuse_public_key` | string | Required |
+|     `langfuse_secret_key` | string | Required |
+|     `langfuse_host` | string | Optional (defaults to https://cloud.langfuse.com) |
+
+## 2. Create key for team
+
+All keys created for team `dbe2f686-a686-4896-864a-4c3924458709` will log to the Langfuse project specified in [Step 1. Set callback for team](#1-set-callback-for-team)
+
+
+```shell
+curl --location 'http://0.0.0.0:4000/key/generate' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "team_id": "dbe2f686-a686-4896-864a-4c3924458709"
+}'
+```
+
+
+## 3. Make `/chat/completions` request for team
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-KbUuE0WNptC0jXapyMmLBA" \
+  -d '{
+    "model": "gpt-4",
+    "messages": [
+      {"role": "user", "content": "Hello, Claude gm!"}
+    ]
+}'
+```
+
+Expect this to be logged to the Langfuse project specified in [Step 1. Set callback for team](#1-set-callback-for-team)
+
+## Team Logging Endpoints
+
+- [`POST /team/{team_id}/callback` - Add a success/failure callback to a team](https://litellm-api.up.railway.app/#/team%20management/add_team_callbacks_team__team_id__callback_post)
+- [`GET /team/{team_id}/callback` - Get the success/failure callbacks and variables for a team](https://litellm-api.up.railway.app/#/team%20management/get_team_callbacks_team__team_id__callback_get)
+
+
+
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 8d77bd85f..54df1f3e3 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -44,19 +44,20 @@ const sidebars = {
         "proxy/cost_tracking",
         "proxy/self_serve",
         "proxy/virtual_keys",
-        "proxy/tag_routing",
-        "proxy/users",
-        "proxy/team_budgets",
-        "proxy/customers",
-        "proxy/billing",
-        "proxy/guardrails",
-        "proxy/token_auth",
-        "proxy/alerting",
         {
           type: "category",
           label: "🪢 Logging",
           items: ["proxy/logging", "proxy/streaming_logging"],
         },
+        "proxy/team_logging",
+        "proxy/guardrails",
+        "proxy/tag_routing",
+        "proxy/users",
+        "proxy/team_budgets",
+        "proxy/customers",
+        "proxy/billing",
+        "proxy/token_auth",
+        "proxy/alerting",
         "proxy/ui",
         "proxy/prometheus",
         "proxy/pass_through",
@@ -192,6 +193,8 @@ const sidebars = {
       items: [
         "observability/langfuse_integration",
         "observability/logfire_integration",
+        "observability/langsmith_integration",
+        "observability/arize_integration",
         "debugging/local_debugging",
         "observability/raw_request_response",
         "observability/custom_callback",
@@ -203,7 +206,6 @@ const sidebars = {
         "observability/openmeter",
         "observability/promptlayer_integration",
         "observability/wandb_integration",
-        "observability/langsmith_integration",
         "observability/slack_integration",
         "observability/traceloop_integration",
"observability/athina_integration", diff --git a/litellm/__init__.py b/litellm/__init__.py index bf3f77385..9bb9a81cd 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -4,7 +4,7 @@ import warnings warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*") ### INIT VARIABLES ### import threading, requests, os -from typing import Callable, List, Optional, Dict, Union, Any, Literal +from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.caching import Cache from litellm._logging import ( @@ -45,7 +45,11 @@ _custom_logger_compatible_callbacks_literal = Literal[ "langsmith", "galileo", "braintrust", + "arize", ] +_known_custom_logger_compatible_callbacks: List = list( + get_args(_custom_logger_compatible_callbacks_literal) +) callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] _langfuse_default_tags: Optional[ List[ @@ -73,6 +77,7 @@ post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False log_raw_request_response: bool = False redact_messages_in_exceptions: Optional[bool] = False +redact_user_api_key_info: Optional[bool] = False store_audit_logs = False # Enterprise feature, allow users to see audit logs ## end of callbacks ############# diff --git a/litellm/integrations/_types/open_inference.py b/litellm/integrations/_types/open_inference.py new file mode 100644 index 000000000..bcfabe9b7 --- /dev/null +++ b/litellm/integrations/_types/open_inference.py @@ -0,0 +1,286 @@ +from enum import Enum + + +class SpanAttributes: + OUTPUT_VALUE = "output.value" + OUTPUT_MIME_TYPE = "output.mime_type" + """ + The type of output.value. If unspecified, the type is plain text by default. + If type is JSON, the value is a string representing a JSON object. + """ + INPUT_VALUE = "input.value" + INPUT_MIME_TYPE = "input.mime_type" + """ + The type of input.value. If unspecified, the type is plain text by default. + If type is JSON, the value is a string representing a JSON object. + """ + + EMBEDDING_EMBEDDINGS = "embedding.embeddings" + """ + A list of objects containing embedding data, including the vector and represented piece of text. + """ + EMBEDDING_MODEL_NAME = "embedding.model_name" + """ + The name of the embedding model. + """ + + LLM_FUNCTION_CALL = "llm.function_call" + """ + For models and APIs that support function calling. Records attributes such as the function + name and arguments to the called function. + """ + LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters" + """ + Invocation parameters passed to the LLM or API, such as the model name, temperature, etc. + """ + LLM_INPUT_MESSAGES = "llm.input_messages" + """ + Messages provided to a chat API. + """ + LLM_OUTPUT_MESSAGES = "llm.output_messages" + """ + Messages received from a chat API. + """ + LLM_MODEL_NAME = "llm.model_name" + """ + The name of the model being used. + """ + LLM_PROMPTS = "llm.prompts" + """ + Prompts provided to a completions API. + """ + LLM_PROMPT_TEMPLATE = "llm.prompt_template.template" + """ + The prompt template as a Python f-string. + """ + LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables" + """ + A list of input variables to the prompt template. + """ + LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version" + """ + The version of the prompt template being used. + """ + LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt" + """ + Number of tokens in the prompt. 
+ """ + LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion" + """ + Number of tokens in the completion. + """ + LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total" + """ + Total number of tokens, including both prompt and completion. + """ + + TOOL_NAME = "tool.name" + """ + Name of the tool being used. + """ + TOOL_DESCRIPTION = "tool.description" + """ + Description of the tool's purpose, typically used to select the tool. + """ + TOOL_PARAMETERS = "tool.parameters" + """ + Parameters of the tool represented a dictionary JSON string, e.g. + see https://platform.openai.com/docs/guides/gpt/function-calling + """ + + RETRIEVAL_DOCUMENTS = "retrieval.documents" + + METADATA = "metadata" + """ + Metadata attributes are used to store user-defined key-value pairs. + For example, LangChain uses metadata to store user-defined attributes for a chain. + """ + + TAG_TAGS = "tag.tags" + """ + Custom categorical tags for the span. + """ + + OPENINFERENCE_SPAN_KIND = "openinference.span.kind" + + SESSION_ID = "session.id" + """ + The id of the session + """ + USER_ID = "user.id" + """ + The id of the user + """ + + +class MessageAttributes: + """ + Attributes for a message sent to or from an LLM + """ + + MESSAGE_ROLE = "message.role" + """ + The role of the message, such as "user", "agent", "function". + """ + MESSAGE_CONTENT = "message.content" + """ + The content of the message to or from the llm, must be a string. + """ + MESSAGE_CONTENTS = "message.contents" + """ + The message contents to the llm, it is an array of + `message_content` prefixed attributes. + """ + MESSAGE_NAME = "message.name" + """ + The name of the message, often used to identify the function + that was used to generate the message. + """ + MESSAGE_TOOL_CALLS = "message.tool_calls" + """ + The tool calls generated by the model, such as function calls. + """ + MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name" + """ + The function name that is a part of the message list. + This is populated for role 'function' or 'agent' as a mechanism to identify + the function that was called during the execution of a tool. + """ + MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json" + """ + The JSON string representing the arguments passed to the function + during a function call. + """ + + +class MessageContentAttributes: + """ + Attributes for the contents of user messages sent to an LLM. + """ + + MESSAGE_CONTENT_TYPE = "message_content.type" + """ + The type of the content, such as "text" or "image". + """ + MESSAGE_CONTENT_TEXT = "message_content.text" + """ + The text content of the message, if the type is "text". + """ + MESSAGE_CONTENT_IMAGE = "message_content.image" + """ + The image content of the message, if the type is "image". + An image can be made available to the model by passing a link to + the image or by passing the base64 encoded image directly in the + request. + """ + + +class ImageAttributes: + """ + Attributes for images + """ + + IMAGE_URL = "image.url" + """ + An http or base64 image url + """ + + +class DocumentAttributes: + """ + Attributes for a document. + """ + + DOCUMENT_ID = "document.id" + """ + The id of the document. + """ + DOCUMENT_SCORE = "document.score" + """ + The score of the document + """ + DOCUMENT_CONTENT = "document.content" + """ + The content of the document. + """ + DOCUMENT_METADATA = "document.metadata" + """ + The metadata of the document represented as a dictionary + JSON string, e.g. 
`"{ 'title': 'foo' }"` + """ + + +class RerankerAttributes: + """ + Attributes for a reranker + """ + + RERANKER_INPUT_DOCUMENTS = "reranker.input_documents" + """ + List of documents as input to the reranker + """ + RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents" + """ + List of documents as output from the reranker + """ + RERANKER_QUERY = "reranker.query" + """ + Query string for the reranker + """ + RERANKER_MODEL_NAME = "reranker.model_name" + """ + Model name of the reranker + """ + RERANKER_TOP_K = "reranker.top_k" + """ + Top K parameter of the reranker + """ + + +class EmbeddingAttributes: + """ + Attributes for an embedding + """ + + EMBEDDING_TEXT = "embedding.text" + """ + The text represented by the embedding. + """ + EMBEDDING_VECTOR = "embedding.vector" + """ + The embedding vector. + """ + + +class ToolCallAttributes: + """ + Attributes for a tool call + """ + + TOOL_CALL_FUNCTION_NAME = "tool_call.function.name" + """ + The name of function that is being called during a tool call. + """ + TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments" + """ + The JSON string representing the arguments passed to the function + during a tool call. + """ + + +class OpenInferenceSpanKindValues(Enum): + TOOL = "TOOL" + CHAIN = "CHAIN" + LLM = "LLM" + RETRIEVER = "RETRIEVER" + EMBEDDING = "EMBEDDING" + AGENT = "AGENT" + RERANKER = "RERANKER" + UNKNOWN = "UNKNOWN" + GUARDRAIL = "GUARDRAIL" + EVALUATOR = "EVALUATOR" + + +class OpenInferenceMimeTypeValues(Enum): + TEXT = "text/plain" + JSON = "application/json" diff --git a/litellm/integrations/arize_ai.py b/litellm/integrations/arize_ai.py new file mode 100644 index 000000000..45c6c1604 --- /dev/null +++ b/litellm/integrations/arize_ai.py @@ -0,0 +1,114 @@ +""" +arize AI is OTEL compatible + +this file has Arize ai specific helper functions +""" + +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +def set_arize_ai_attributes(span: Span, kwargs, response_obj): + from litellm.integrations._types.open_inference import ( + MessageAttributes, + MessageContentAttributes, + OpenInferenceSpanKindValues, + SpanAttributes, + ) + + optional_params = kwargs.get("optional_params", {}) + litellm_params = kwargs.get("litellm_params", {}) or {} + + ############################################# + ############ LLM CALL METADATA ############## + ############################################# + # commented out for now - looks like Arize AI could not log this + # metadata = litellm_params.get("metadata", {}) or {} + # span.set_attribute(SpanAttributes.METADATA, str(metadata)) + + ############################################# + ########## LLM Request Attributes ########### + ############################################# + + # The name of the LLM a request is being made to + if kwargs.get("model"): + span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model")) + + span.set_attribute( + SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.LLM.value + ) + messages = kwargs.get("messages") + + # for /chat/completions + # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions + if messages: + span.set_attribute( + SpanAttributes.INPUT_VALUE, + messages[-1].get("content", ""), # get the last message for input + ) + + # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page + for idx, msg in enumerate(messages): + # Set the role per message + span.set_attribute( + 
f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}", + msg["role"], + ) + # Set the content per message + span.set_attribute( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}", + msg.get("content", ""), + ) + + # The Generative AI Provider: Azure, OpenAI, etc. + span.set_attribute(SpanAttributes.LLM_INVOCATION_PARAMETERS, str(optional_params)) + + if optional_params.get("user"): + span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user")) + + ############################################# + ########## LLM Response Attributes ########## + # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions + ############################################# + for choice in response_obj.get("choices"): + response_message = choice.get("message", {}) + span.set_attribute( + SpanAttributes.OUTPUT_VALUE, response_message.get("content", "") + ) + + # This shows up under `output_messages` tab on the span page + # This code assumes a single response + span.set_attribute( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", + response_message["role"], + ) + span.set_attribute( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", + response_message.get("content", ""), + ) + + usage = response_obj.get("usage") + if usage: + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_TOTAL, + usage.get("total_tokens"), + ) + + # The number of tokens used in the LLM response (completion). + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, + usage.get("completion_tokens"), + ) + + # The number of tokens used in the LLM prompt. + span.set_attribute( + SpanAttributes.LLM_TOKEN_COUNT_PROMPT, + usage.get("prompt_tokens"), + ) + pass diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 0647afabc..0217f7458 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -8,6 +8,7 @@ from packaging.version import Version import litellm from litellm._logging import verbose_logger +from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info class LangFuseLogger: @@ -382,6 +383,8 @@ class LangFuseLogger: mask_input = clean_metadata.pop("mask_input", False) mask_output = clean_metadata.pop("mask_output", False) + clean_metadata = redact_user_api_key_info(metadata=clean_metadata) + if trace_name is None and existing_trace_id is None: # just log `litellm-{call_type}` as the trace name ## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces. 
diff --git a/litellm/integrations/logfire_logger.py b/litellm/integrations/logfire_logger.py index b4ab00820..fa4ab7bd5 100644 --- a/litellm/integrations/logfire_logger.py +++ b/litellm/integrations/logfire_logger.py @@ -1,17 +1,21 @@ #### What this does #### # On success + failure, log events to Logfire -import dotenv, os +import os + +import dotenv dotenv.load_dotenv() # Loading env variables using dotenv import traceback import uuid -from litellm._logging import print_verbose, verbose_logger - from enum import Enum from typing import Any, Dict, NamedTuple + from typing_extensions import LiteralString +from litellm._logging import print_verbose, verbose_logger +from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info + class SpanConfig(NamedTuple): message_template: LiteralString @@ -135,6 +139,8 @@ class LogfireLogger: else: clean_metadata[key] = value + clean_metadata = redact_user_api_key_info(metadata=clean_metadata) + # Build the initial payload payload = { "id": id, diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 215a4f09f..c47911b4f 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -2,11 +2,12 @@ import os from dataclasses import dataclass from datetime import datetime from functools import wraps -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import litellm from litellm._logging import verbose_logger from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info from litellm.types.services import ServiceLoggerPayload if TYPE_CHECKING: @@ -27,9 +28,10 @@ else: LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm") -LITELLM_RESOURCE = { +LITELLM_RESOURCE: Dict[Any, Any] = { "service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"), "deployment.environment": os.getenv("OTEL_ENVIRONMENT_NAME", "production"), + "model_id": os.getenv("OTEL_SERVICE_NAME", "litellm"), } RAW_REQUEST_SPAN_NAME = "raw_gen_ai_request" LITELLM_REQUEST_SPAN_NAME = "litellm_request" @@ -68,7 +70,9 @@ class OpenTelemetryConfig: class OpenTelemetry(CustomLogger): - def __init__(self, config=OpenTelemetryConfig.from_env()): + def __init__( + self, config=OpenTelemetryConfig.from_env(), callback_name: Optional[str] = None + ): from opentelemetry import trace from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider @@ -79,6 +83,7 @@ class OpenTelemetry(CustomLogger): self.OTEL_HEADERS = self.config.headers provider = TracerProvider(resource=Resource(attributes=LITELLM_RESOURCE)) provider.add_span_processor(self._get_span_processor()) + self.callback_name = callback_name trace.set_tracer_provider(provider) self.tracer = trace.get_tracer(LITELLM_TRACER_NAME) @@ -120,8 +125,8 @@ class OpenTelemetry(CustomLogger): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode - _start_time_ns = start_time - _end_time_ns = end_time + _start_time_ns = 0 + _end_time_ns = 0 if isinstance(start_time, float): _start_time_ns = int(int(start_time) * 1e9) @@ -159,8 +164,8 @@ class OpenTelemetry(CustomLogger): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode - _start_time_ns = start_time - _end_time_ns = end_time + _start_time_ns = 0 + _end_time_ns = 0 if isinstance(start_time, float): _start_time_ns = int(int(start_time) * 1e9) @@ -294,6 +299,11 @@ class 
OpenTelemetry(CustomLogger): return isinstance(value, (str, bool, int, float)) def set_attributes(self, span: Span, kwargs, response_obj): + if self.callback_name == "arize": + from litellm.integrations.arize_ai import set_arize_ai_attributes + + set_arize_ai_attributes(span, kwargs, response_obj) + return from litellm.proxy._types import SpanAttributes optional_params = kwargs.get("optional_params", {}) @@ -306,7 +316,9 @@ class OpenTelemetry(CustomLogger): ############################################# metadata = litellm_params.get("metadata", {}) or {} - for key, value in metadata.items(): + clean_metadata = redact_user_api_key_info(metadata=metadata) + + for key, value in clean_metadata.items(): if self.is_primitive(value): span.set_attribute("metadata.{}".format(key), value) @@ -612,8 +624,8 @@ class OpenTelemetry(CustomLogger): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode - _start_time_ns = logging_payload.start_time - _end_time_ns = logging_payload.end_time + _start_time_ns = 0 + _end_time_ns = 0 start_time = logging_payload.start_time end_time = logging_payload.end_time @@ -658,8 +670,8 @@ class OpenTelemetry(CustomLogger): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode - _start_time_ns = logging_payload.start_time - _end_time_ns = logging_payload.end_time + _start_time_ns = 0 + _end_time_ns = 0 start_time = logging_payload.start_time end_time = logging_payload.end_time diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 17837c41e..0785933aa 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1962,6 +1962,43 @@ def _init_custom_logger_compatible_class( _langsmith_logger = LangsmithLogger() _in_memory_loggers.append(_langsmith_logger) return _langsmith_logger # type: ignore + elif logging_integration == "arize": + if "ARIZE_SPACE_KEY" not in os.environ: + raise ValueError("ARIZE_SPACE_KEY not found in environment variables") + if "ARIZE_API_KEY" not in os.environ: + raise ValueError("ARIZE_API_KEY not found in environment variables") + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + otel_config = OpenTelemetryConfig( + exporter="otlp_grpc", + endpoint="https://otlp.arize.com/v1", + ) + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"space_key={os.getenv('ARIZE_SPACE_KEY')},api_key={os.getenv('ARIZE_API_KEY')}" + ) + for callback in _in_memory_loggers: + if ( + isinstance(callback, OpenTelemetry) + and callback.callback_name == "arize" + ): + return callback # type: ignore + _otel_logger = OpenTelemetry(config=otel_config, callback_name="arize") + _in_memory_loggers.append(_otel_logger) + return _otel_logger # type: ignore + + elif logging_integration == "otel": + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if isinstance(callback, OpenTelemetry): + return callback # type: ignore + + otel_logger = OpenTelemetry() + _in_memory_loggers.append(otel_logger) + return otel_logger # type: ignore elif logging_integration == "galileo": for callback in _in_memory_loggers: @@ -2039,6 +2076,25 @@ def get_custom_logger_compatible_class( for callback in _in_memory_loggers: if isinstance(callback, LangsmithLogger): return callback + elif logging_integration == "otel": + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if 
isinstance(callback, OpenTelemetry): + return callback + elif logging_integration == "arize": + from litellm.integrations.opentelemetry import OpenTelemetry + + if "ARIZE_SPACE_KEY" not in os.environ: + raise ValueError("ARIZE_SPACE_KEY not found in environment variables") + if "ARIZE_API_KEY" not in os.environ: + raise ValueError("ARIZE_API_KEY not found in environment variables") + for callback in _in_memory_loggers: + if ( + isinstance(callback, OpenTelemetry) + and callback.callback_name == "arize" + ): + return callback elif logging_integration == "logfire": if "LOGFIRE_TOKEN" not in os.environ: raise ValueError("LOGFIRE_TOKEN not found in environment variables") diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index 378c46ba0..7f342e271 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -87,3 +87,33 @@ def redact_message_input_output_from_logging( # by default return result return result + + +def redact_user_api_key_info(metadata: dict) -> dict: + """ + removes any user_api_key_info before passing to logging object, if flag set + + Usage: + + SDK + ```python + litellm.redact_user_api_key_info = True + ``` + + PROXY: + ```yaml + litellm_settings: + redact_user_api_key_info: true + ``` + """ + if litellm.redact_user_api_key_info is not True: + return metadata + + new_metadata = {} + for k, v in metadata.items(): + if isinstance(k, str) and k.startswith("user_api_key"): + pass + else: + new_metadata[k] = v + + return new_metadata diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index d1e0d14ba..25e2e518c 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -968,7 +968,7 @@ class OpenAIChatCompletion(BaseLLM): except openai.UnprocessableEntityError as e: ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 if litellm.drop_params is True or drop_params is True: - if e.body is not None and e.body.get("detail"): # type: ignore + if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore detail = e.body.get("detail") # type: ignore invalid_params: List[str] = [] if ( @@ -1100,7 +1100,7 @@ class OpenAIChatCompletion(BaseLLM): except openai.UnprocessableEntityError as e: ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 if litellm.drop_params is True or drop_params is True: - if e.body is not None and e.body.get("detail"): # type: ignore + if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore detail = e.body.get("detail") # type: ignore invalid_params: List[str] = [] if ( @@ -1231,7 +1231,7 @@ class OpenAIChatCompletion(BaseLLM): except openai.UnprocessableEntityError as e: ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 if litellm.drop_params is True or drop_params is True: - if e.body is not None and e.body.get("detail"): # type: ignore + if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore detail = e.body.get("detail") # type: ignore invalid_params: List[str] = [] if ( diff --git a/litellm/main.py b/litellm/main.py index 8cb52d945..4e2df72cd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1491,6 +1491,10 @@ def completion( or get_secret("ANTHROPIC_BASE_URL") or "https://api.anthropic.com/v1/complete" ) + + if api_base is not None and not 
api_base.endswith("/v1/complete"): + api_base += "/v1/complete" + response = anthropic_text_completions.completion( model=model, messages=messages, @@ -1517,6 +1521,10 @@ def completion( or get_secret("ANTHROPIC_BASE_URL") or "https://api.anthropic.com/v1/messages" ) + + if api_base is not None and not api_base.endswith("/v1/messages"): + api_base += "/v1/messages" + response = anthropic_chat_completions.completion( model=model, messages=messages, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 7a35650e5..16570cbe1 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -4,6 +4,6 @@ model_list: model: groq/llama3-groq-70b-8192-tool-use-preview api_key: os.environ/GROQ_API_KEY - litellm_settings: - callbacks: ["braintrust"] + callbacks: ["logfire"] + redact_user_api_key_info: true diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e9371c1d8..25aa942e5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -228,6 +228,10 @@ class LiteLLMRoutes(enum.Enum): "/utils/token_counter", ] + anthropic_routes: List = [ + "/v1/messages", + ] + info_routes: List = [ "/key/info", "/team/info", @@ -880,6 +884,26 @@ class BlockTeamRequest(LiteLLMBase): team_id: str # required +class AddTeamCallback(LiteLLMBase): + callback_name: str + callback_type: Literal["success", "failure", "success_and_failure"] + # for now - only supported for langfuse + callback_vars: Dict[ + Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str + ] + + +class TeamCallbackMetadata(LiteLLMBase): + success_callback: Optional[List[str]] = [] + failure_callback: Optional[List[str]] = [] + # for now - only supported for langfuse + callback_vars: Optional[ + Dict[ + Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str + ] + ] = {} + + class LiteLLM_TeamTable(TeamBase): spend: Optional[float] = None max_parallel_requests: Optional[int] = None @@ -1232,6 +1256,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None team_member_spend: Optional[float] = None + team_metadata: Optional[Dict] = None # End User Params end_user_id: Optional[str] = None @@ -1677,3 +1702,5 @@ class ProxyErrorTypes(str, enum.Enum): budget_exceeded = "budget_exceeded" expired_key = "expired_key" auth_error = "auth_error" + internal_server_error = "internal_server_error" + bad_request_error = "bad_request_error" diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 96171f2ef..91d4b1938 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -24,7 +24,7 @@ from litellm.proxy._types import ( LitellmUserRoles, UserAPIKeyAuth, ) -from litellm.proxy.auth.auth_utils import is_openai_route +from litellm.proxy.auth.auth_utils import is_llm_api_route from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry from litellm.types.services import ServiceLoggerPayload, ServiceTypes @@ -57,6 +57,7 @@ def common_checks( 4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget 5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget + 7. 
[OPTIONAL] If guardrails modified - is request allowed to change this """ _model = request_body.get("model", None) if team_object is not None and team_object.blocked is True: @@ -106,7 +107,7 @@ def common_checks( general_settings.get("enforce_user_param", None) is not None and general_settings["enforce_user_param"] == True ): - if is_openai_route(route=route) and "user" not in request_body: + if is_llm_api_route(route=route) and "user" not in request_body: raise Exception( f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" ) @@ -122,7 +123,7 @@ def common_checks( + CommonProxyErrors.not_premium_user.value ) - if is_openai_route(route=route): + if is_llm_api_route(route=route): # loop through each enforced param # example enforced_params ['user', 'metadata', 'metadata.generation_name'] for enforced_param in general_settings["enforced_params"]: @@ -150,7 +151,7 @@ def common_checks( and global_proxy_spend is not None # only run global budget checks for OpenAI routes # Reason - the Admin UI should continue working if the proxy crosses it's global budget - and is_openai_route(route=route) + and is_llm_api_route(route=route) and route != "/v1/models" and route != "/models" ): @@ -158,6 +159,22 @@ def common_checks( raise litellm.BudgetExceededError( current_cost=global_proxy_spend, max_budget=litellm.max_budget ) + + _request_metadata: dict = request_body.get("metadata", {}) or {} + if _request_metadata.get("guardrails"): + # check if team allowed to modify guardrails + from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails + + can_modify: bool = can_modify_guardrails(team_object) + if can_modify is False: + from fastapi import HTTPException + + raise HTTPException( + status_code=403, + detail={ + "error": "Your team does not have permission to modify guardrails." 
+ }, + ) return True diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index d3e030762..bd1e50ed0 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -46,7 +46,7 @@ def route_in_additonal_public_routes(current_route: str): return False -def is_openai_route(route: str) -> bool: +def is_llm_api_route(route: str) -> bool: """ Helper to checks if provided route is an OpenAI route @@ -59,6 +59,9 @@ def is_openai_route(route: str) -> bool: if route in LiteLLMRoutes.openai_routes.value: return True + if route in LiteLLMRoutes.anthropic_routes.value: + return True + # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ" # Check for routes with placeholders for openai_route in LiteLLMRoutes.openai_routes.value: diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index c5549ffcb..d91baf5ca 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -57,7 +57,7 @@ from litellm.proxy.auth.auth_checks import ( log_to_opentelemetry, ) from litellm.proxy.auth.auth_utils import ( - is_openai_route, + is_llm_api_route, route_in_additonal_public_routes, ) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body @@ -924,6 +924,7 @@ async def user_api_key_auth( rpm_limit=valid_token.team_rpm_limit, blocked=valid_token.team_blocked, models=valid_token.team_models, + metadata=valid_token.team_metadata, ) user_api_key_cache.set_cache( @@ -994,9 +995,9 @@ async def user_api_key_auth( _user_role = _get_user_role(user_id_information=user_id_information) if not _is_user_proxy_admin(user_id_information): # if non-admin - if is_openai_route(route=route): + if is_llm_api_route(route=route): pass - elif is_openai_route(route=request["route"].name): + elif is_llm_api_route(route=request["route"].name): pass elif ( route in LiteLLMRoutes.info_routes.value @@ -1049,7 +1050,7 @@ async def user_api_key_auth( pass elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value: - if is_openai_route(route=route): + if is_llm_api_route(route=route): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"user not allowed to access this OpenAI routes, role= {_user_role}", diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py index 2fcceaa29..eaa926fed 100644 --- a/litellm/proxy/common_utils/init_callbacks.py +++ b/litellm/proxy/common_utils/init_callbacks.py @@ -23,12 +23,11 @@ def initialize_callbacks_on_proxy( ) if isinstance(value, list): imported_list: List[Any] = [] - known_compatible_callbacks = list( - get_args(litellm._custom_logger_compatible_callbacks_literal) - ) for callback in value: # ["presidio", ] - - if isinstance(callback, str) and callback in known_compatible_callbacks: + if ( + isinstance(callback, str) + and callback in litellm._known_custom_logger_compatible_callbacks + ): imported_list.append(callback) elif isinstance(callback, str) and callback == "otel": from litellm.integrations.opentelemetry import OpenTelemetry diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index d6a081b4d..e0a5f1eb3 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -1,9 +1,26 @@ +from typing import Dict + import litellm from litellm._logging import verbose_proxy_logger -from litellm.proxy.proxy_server import UserAPIKeyAuth +from litellm.proxy.proxy_server import 
LiteLLM_TeamTable, UserAPIKeyAuth from litellm.types.guardrails import * +def can_modify_guardrails(team_obj: Optional[LiteLLM_TeamTable]) -> bool: + if team_obj is None: + return True + + team_metadata = team_obj.metadata or {} + + if team_metadata.get("guardrails", None) is not None and isinstance( + team_metadata.get("guardrails"), Dict + ): + if team_metadata.get("guardrails", {}).get("modify_guardrails", None) is False: + return False + + return True + + async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool: """ checks if this guardrail should be applied to this call diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 1014a325a..642c12616 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from fastapi import Request from litellm._logging import verbose_logger, verbose_proxy_logger -from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth +from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth from litellm.types.utils import SupportedCacheControls if TYPE_CHECKING: @@ -207,6 +207,29 @@ async def add_litellm_data_to_request( **data, } # add the team-specific configs to the completion call + # Team Callbacks controls + if user_api_key_dict.team_metadata is not None: + team_metadata = user_api_key_dict.team_metadata + if "callback_settings" in team_metadata: + callback_settings = team_metadata.get("callback_settings", None) or {} + callback_settings_obj = TeamCallbackMetadata(**callback_settings) + """ + callback_settings = { + { + 'callback_vars': {'langfuse_public_key': 'pk', 'langfuse_secret_key': 'sk_'}, + 'failure_callback': [], + 'success_callback': ['langfuse', 'langfuse'] + } + } + """ + data["success_callback"] = callback_settings_obj.success_callback + data["failure_callback"] = callback_settings_obj.failure_callback + + if callback_settings_obj.callback_vars is not None: + # unpack callback_vars in data + for k, v in callback_settings_obj.callback_vars.items(): + data[k] = v + return data diff --git a/litellm/proxy/management_endpoints/team_callback_endpoints.py b/litellm/proxy/management_endpoints/team_callback_endpoints.py new file mode 100644 index 000000000..9c2ac65cc --- /dev/null +++ b/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -0,0 +1,279 @@ +""" +Endpoints to control callbacks per team + +Use this when each team should control its own callbacks +""" + +import asyncio +import copy +import json +import traceback +import uuid +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +import fastapi +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + AddTeamCallback, + LiteLLM_TeamTable, + ProxyErrorTypes, + ProxyException, + TeamCallbackMetadata, + UserAPIKeyAuth, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_helpers.utils import ( + add_new_member, + management_endpoint_wrapper, +) + +router = APIRouter() + + +@router.post( + "/team/{team_id:path}/callback", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def add_team_callbacks( + data: AddTeamCallback, + http_request: Request, + team_id: str, + user_api_key_dict: UserAPIKeyAuth = 
Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Add a success/failure callback to a team + + Use this if if you want different teams to have different success/failure callbacks + + Example curl: + ``` + curl -X POST 'http:/localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-1234' \ + -d '{ + "callback_name": "langfuse", + "callback_type": "success", + "callback_vars": {"langfuse_public_key": "pk-lf-xxxx1", "langfuse_secret_key": "sk-xxxxx"} + + }' + ``` + + This means for the team where team_id = dbe2f686-a686-4896-864a-4c3924458709, all LLM calls will be logged to langfuse using the public key pk-lf-xxxx1 and the secret key sk-xxxxx + + """ + try: + from litellm.proxy.proxy_server import ( + _duration_in_seconds, + create_audit_log_for_update, + litellm_proxy_admin_name, + prisma_client, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if team_id exists already + _existing_team = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if _existing_team is None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {team_id} does not exist. Please use a different team id." + }, + ) + + # store team callback settings in metadata + team_metadata = _existing_team.metadata + team_callback_settings = team_metadata.get("callback_settings", {}) + # expect callback settings to be + team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings) + if data.callback_type == "success": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + + if data.callback_name in team_callback_settings_obj.success_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.success_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + team_callback_settings_obj.success_callback.append(data.callback_name) + elif data.callback_type == "failure": + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + + if data.callback_name in team_callback_settings_obj.failure_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.failure_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + team_callback_settings_obj.failure_callback.append(data.callback_name) + elif data.callback_type == "success_and_failure": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + if data.callback_name in team_callback_settings_obj.success_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in success_callback, for team_id = {team_id}. 
\n Existing success_callback = {team_callback_settings_obj.success_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + if data.callback_name in team_callback_settings_obj.failure_callback: + raise ProxyException( + message=f"callback_name = {data.callback_name} already exists in failure_callback, for team_id = {team_id}. \n Existing failure_callback = {team_callback_settings_obj.failure_callback}", + code=status.HTTP_400_BAD_REQUEST, + type=ProxyErrorTypes.bad_request_error, + param="callback_name", + ) + + team_callback_settings_obj.success_callback.append(data.callback_name) + team_callback_settings_obj.failure_callback.append(data.callback_name) + for var, value in data.callback_vars.items(): + if team_callback_settings_obj.callback_vars is None: + team_callback_settings_obj.callback_vars = {} + team_callback_settings_obj.callback_vars[var] = value + + team_callback_settings_obj_dict = team_callback_settings_obj.model_dump() + + team_metadata["callback_settings"] = team_callback_settings_obj_dict + team_metadata_json = json.dumps(team_metadata) # update team_metadata + + new_team_row = await prisma_client.db.litellm_teamtable.update( + where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore + ) + + return { + "status": "success", + "data": new_team_row, + } + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.add_team_callbacks(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get( + "/team/{team_id:path}/callback", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def get_team_callbacks( + http_request: Request, + team_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get the success/failure callbacks and variables for a team + + Example curl: + ``` + curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \ + -H 'Authorization: Bearer sk-1234' + ``` + + This will return the callback settings for the team with id dbe2f686-a686-4896-864a-4c3924458709 + + Returns { + "status": "success", + "data": { + "team_id": team_id, + "success_callbacks": team_callback_settings_obj.success_callback, + "failure_callbacks": team_callback_settings_obj.failure_callback, + "callback_vars": team_callback_settings_obj.callback_vars, + }, + } + """ + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if team_id exists + _existing_team = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if _existing_team is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team id = {team_id} does not exist."}, + ) + + # Retrieve team callback settings from metadata + 
team_metadata = _existing_team.metadata + team_callback_settings = team_metadata.get("callback_settings", {}) + + # Convert to TeamCallbackMetadata object for consistent structure + team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings) + + return { + "status": "success", + "data": { + "team_id": team_id, + "success_callbacks": team_callback_settings_obj.success_callback, + "failure_callbacks": team_callback_settings_obj.failure_callback, + "callback_vars": team_callback_settings_obj.callback_vars, + }, + } + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.get_team_callbacks(): Exception occurred - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Internal Server Error({str(e)})"), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Internal Server Error, " + str(e), + type=ProxyErrorTypes.internal_server_error.value, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index bb98a02ec..9ba76a203 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -363,6 +363,7 @@ async def update_team( # set the budget_reset_at in DB updated_kv["budget_reset_at"] = reset_at + updated_kv = prisma_client.jsonify_object(data=updated_kv) team_row: Optional[ LiteLLM_TeamTable ] = await prisma_client.db.litellm_teamtable.update( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index c114db25f..2508a48a1 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,10 +1,15 @@ model_list: + - model_name: gpt-4 + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ - model_name: fireworks-llama-v3-70b-instruct litellm_params: model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct - api_key: "os.environ/FIREWORKS_AI_API_KEY" - -router_settings: - enable_tag_filtering: True # 👈 Key Change + api_key: "os.environ/FIREWORKS" general_settings: - master_key: sk-1234 \ No newline at end of file + master_key: sk-1234 + +litellm_settings: + callbacks: ["arize"] \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 79f25c6e1..3ab864381 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -170,6 +170,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import ( router as key_management_router, ) +from litellm.proxy.management_endpoints.team_callback_endpoints import ( + router as team_callback_router, +) from litellm.proxy.management_endpoints.team_endpoints import router as team_router from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, @@ -9457,3 +9460,4 @@ app.include_router(analytics_router) app.include_router(debugging_endpoints_router) app.include_router(ui_crud_endpoints_router) app.include_router(openai_files_router) +app.include_router(team_callback_router) 
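
With the router registered above, the proxy-side flow is complete: `POST /team/{team_id}/callback` stores the settings under the team's `metadata["callback_settings"]`, and `add_litellm_data_to_request` (see the `litellm_pre_call_utils.py` hunk earlier in this diff) copies them into each request. Below is a simplified sketch of that unpacking, using plain dicts in place of the `TeamCallbackMetadata` model and illustrative key values.

```python
# Simplified sketch of the per-team callback unpacking done in
# litellm/proxy/litellm_pre_call_utils.py (plain dicts instead of TeamCallbackMetadata).
team_metadata = {
    "callback_settings": {
        "success_callback": ["langfuse"],
        "failure_callback": [],
        "callback_vars": {  # illustrative values
            "langfuse_public_key": "pk-lf-...",
            "langfuse_secret_key": "sk-lf-...",
            "langfuse_host": "https://cloud.langfuse.com",
        },
    }
}

data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello, Claude gm!"}]}

callback_settings = team_metadata.get("callback_settings", None) or {}
data["success_callback"] = callback_settings.get("success_callback", [])
data["failure_callback"] = callback_settings.get("failure_callback", [])
# callback_vars (the team's Langfuse keys) are unpacked as top-level request fields
for k, v in (callback_settings.get("callback_vars", None) or {}).items():
    data[k] = v

print(data["success_callback"], data["langfuse_host"])
# ['langfuse'] https://cloud.langfuse.com
```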
diff --git a/litellm/proxy/tests/test_anthropic_sdk.py b/litellm/proxy/tests/test_anthropic_sdk.py
new file mode 100644
index 000000000..073fafb07
--- /dev/null
+++ b/litellm/proxy/tests/test_anthropic_sdk.py
@@ -0,0 +1,22 @@
+import os
+
+from anthropic import Anthropic
+
+client = Anthropic(
+    # point the Anthropic SDK at the local LiteLLM proxy instead of the Anthropic API
+    base_url="http://localhost:4000",
+    # this is a litellm proxy key :) - not a real anthropic key
+    api_key="sk-s4xN1IiLTCytwtZFJaYQrA",
+)
+
+message = client.messages.create(
+    max_tokens=1024,
+    messages=[
+        {
+            "role": "user",
+            "content": "Hello, Claude",
+        }
+    ],
+    model="claude-3-opus-20240229",
+)
+print(message.content)
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 0f87e962a..a982c6cd7 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -25,7 +25,7 @@ from typing_extensions import overload
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse
+from litellm import EmbeddingResponse, ImageResponse, ModelResponse, get_litellm_params
 from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching import DualCache, RedisCache
@@ -50,7 +50,7 @@ from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
-from litellm.types.utils import CallTypes
+from litellm.types.utils import CallTypes, LoggedLiteLLMParams
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -602,14 +602,20 @@ class ProxyLogging:
         if litellm_logging_obj is not None:
             ## UPDATE LOGGING INPUT
             _optional_params = {}
+            _litellm_params = {}
+
+            litellm_param_keys = LoggedLiteLLMParams.__annotations__.keys()
             for k, v in request_data.items():
-                if k != "model" and k != "user" and k != "litellm_params":
+                if k in litellm_param_keys:
+                    _litellm_params[k] = v
+                elif k != "model" and k != "user":
                     _optional_params[k] = v
+
             litellm_logging_obj.update_environment_variables(
                 model=request_data.get("model", ""),
                 user=request_data.get("user", ""),
                 optional_params=_optional_params,
-                litellm_params=request_data.get("litellm_params", {}),
+                litellm_params=_litellm_params,
             )
 
             input: Union[list, str, dict] = ""
@@ -1313,8 +1319,10 @@ class PrismaClient:
                     t.tpm_limit AS team_tpm_limit,
                     t.rpm_limit AS team_rpm_limit,
                     t.models AS team_models,
+                    t.metadata AS team_metadata,
                     t.blocked AS team_blocked,
                     t.team_alias AS team_alias,
+                    t.metadata AS team_metadata,
                     tm.spend AS team_member_spend,
                     m.aliases as team_model_aliases
                 FROM "LiteLLM_VerificationToken" AS v
diff --git a/litellm/tests/test_arize_ai.py b/litellm/tests/test_arize_ai.py
new file mode 100644
index 000000000..dfc00446e
--- /dev/null
+++ b/litellm/tests/test_arize_ai.py
@@ -0,0 +1,29 @@
+import asyncio
+import logging
+import os
+import time
+
+import pytest
+from dotenv import load_dotenv
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
+
+load_dotenv()
+import logging
+
+
+@pytest.mark.asyncio()
+async def test_async_otel_callback():
+    litellm.set_verbose = True
+    litellm.success_callback = ["arize"]
+
+    await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "hi test from local arize"}],
+        mock_response="hello",
+        temperature=0.1,
+        user="OTEL_USER",
+    )
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index eae0412d3..9c18899a5 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -234,6 +234,7 @@ class CompletionCustomHandler(
             )
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
+            assert isinstance(kwargs["litellm_params"]["metadata"], (dict, type(None)))
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
             assert isinstance(kwargs["user"], (str, type(None)))
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index 94ece7305..66c8594bb 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -64,6 +64,30 @@ async def test_content_policy_exception_azure():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+@pytest.mark.asyncio
+async def test_content_policy_exception_openai():
+    try:
+        # this is only a test - we needed some way to invoke the exception :(
+        litellm.set_verbose = True
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo-0613",
+            stream=True,
+            messages=[
+                {"role": "user", "content": "Gimme the lyrics to Don't Stop Me Now"}
+            ],
+        )
+        async for chunk in response:
+            print(chunk)
+    except litellm.ContentPolicyViolationError as e:
+        print("caught a content policy violation error! Passed")
+        print("exception", e)
+        assert e.llm_provider == "openai"
+        pass
+    except Exception as e:
+        print()
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+
 # Test 1: Context Window Errors
 @pytest.mark.skip(reason="AWS Suspended Account")
 @pytest.mark.parametrize("model", exception_models)
diff --git a/litellm/tests/test_proxy_routes.py b/litellm/tests/test_proxy_routes.py
index 776ad1e78..6f5774d3e 100644
--- a/litellm/tests/test_proxy_routes.py
+++ b/litellm/tests/test_proxy_routes.py
@@ -19,7 +19,7 @@ import pytest
 
 import litellm
 from litellm.proxy._types import LiteLLMRoutes
-from litellm.proxy.auth.auth_utils import is_openai_route
+from litellm.proxy.auth.auth_utils import is_llm_api_route
 from litellm.proxy.proxy_server import app
 
 # Configure logging
@@ -77,8 +77,8 @@ def test_routes_on_litellm_proxy():
         ("/v1/non_existent_endpoint", False),
     ],
 )
-def test_is_openai_route(route: str, expected: bool):
-    assert is_openai_route(route) == expected
+def test_is_llm_api_route(route: str, expected: bool):
+    assert is_llm_api_route(route) == expected
 
 
 # Test case for routes that are similar but should return False
@@ -91,5 +91,10 @@ def test_is_openai_route(route: str, expected: bool):
         "/engines/model/invalid/completions",
     ],
 )
-def test_is_openai_route_similar_but_false(route: str):
-    assert is_openai_route(route) == False
+def test_is_llm_api_route_similar_but_false(route: str):
+    assert is_llm_api_route(route) == False
+
+
+def test_anthropic_api_routes():
+    # allow non-proxy-admins to call the Anthropic API routes
+    assert is_llm_api_route(route="/v1/messages") is True
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index ed7451c27..f3cb69a08 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -173,6 +173,63 @@ def test_chat_completion(mock_acompletion, client_no_auth):
         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
+@mock_patch_acompletion()
+@pytest.mark.asyncio
+async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
+    """
+    If a team is not allowed to turn guardrails on/off:
+
+    Raise a 403 Forbidden error when the request is made by that team on `/key/generate` or `/chat/completions`.
+    """
+    import asyncio
+    import json
+    import time
+
+    from fastapi import HTTPException, Request
+    from starlette.datastructures import URL
+
+    from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, UserAPIKeyAuth
+    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+    from litellm.proxy.proxy_server import hash_token, user_api_key_cache
+
+    _team_id = "1234"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        team_blocked=True,
+        token=hash_token(user_key),
+        last_refreshed_at=time.time(),
+    )
+    await asyncio.sleep(1)
+    team_obj = LiteLLM_TeamTable(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    body = {"metadata": {"guardrails": {"hide_secrets": False}}}
+    json_bytes = json.dumps(body).encode("utf-8")
+
+    request._body = json_bytes
+
+    try:
+        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+        pytest.fail("Expected to raise 403 forbidden error.")
+    except ProxyException as e:
+        assert e.code == 403
+
+
 from litellm.tests.test_custom_callback_input import CompletionCustomHandler
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 6581fea5f..88bfa19e9 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1029,3 +1029,22 @@ class GenericImageParsingChunk(TypedDict):
 class ResponseFormatChunk(TypedDict, total=False):
     type: Required[Literal["json_object", "text"]]
     response_schema: dict
+
+
+class LoggedLiteLLMParams(TypedDict, total=False):
+    force_timeout: Optional[float]
+    custom_llm_provider: Optional[str]
+    api_base: Optional[str]
+    litellm_call_id: Optional[str]
+    model_alias_map: Optional[dict]
+    metadata: Optional[dict]
+    model_info: Optional[dict]
+    proxy_server_request: Optional[dict]
+    acompletion: Optional[bool]
+    preset_cache_key: Optional[str]
+    no_log: Optional[bool]
+    input_cost_per_second: Optional[float]
+    input_cost_per_token: Optional[float]
+    output_cost_per_token: Optional[float]
+    output_cost_per_second: Optional[float]
+    cooldown_time: Optional[float]
diff --git a/litellm/utils.py b/litellm/utils.py
index 5ec7b52f5..9d798f119 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -158,6 +158,7 @@ from typing import (
     Tuple,
     Union,
     cast,
+    get_args,
 )
 
 from .caching import Cache
@@ -405,7 +406,6 @@ def function_setup(
                 # Pop the async items from input_callback in reverse order to avoid index issues
                 for index in reversed(removed_async_items):
                     litellm.input_callback.pop(index)
-
            if len(litellm.success_callback) > 0:
                removed_async_items = []
                for index, callback in enumerate(litellm.success_callback):  # type: ignore
@@ -417,9 +417,9 @@ def function_setup(
                        # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
                        litellm._async_success_callback.append(callback)
                        removed_async_items.append(index)
-                    elif callback == "langsmith":
+                    elif callback in litellm._known_custom_logger_compatible_callbacks:
                        callback_class = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                            callback, internal_usage_cache=None, llm_router=None
+                            callback, internal_usage_cache=None, llm_router=None  # type: ignore
                        )
 
                        # don't double add a callback
@@ -8808,11 +8808,14 @@ class CustomStreamWrapper:
                            str_line.choices[0].content_filter_result
                        )
                    else:
-                        error_message = "Azure Response={}".format(
-                            str(dict(str_line))
+                        error_message = "{} Response={}".format(
+                            self.custom_llm_provider, str(dict(str_line))
                        )
-                        raise litellm.AzureOpenAIError(
-                            status_code=400, message=error_message
+
+                        raise litellm.ContentPolicyViolationError(
+                            message=error_message,
+                            llm_provider=self.custom_llm_provider,
+                            model=self.model,
                        )
 
                # checking for logprobs
diff --git a/pyproject.toml b/pyproject.toml
index b14fb819a..5dc8ab62d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.41.26"
+version = "1.41.27"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.41.26"
+version = "1.41.27"
 version_files = [
     "pyproject.toml:^version"
 ]
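One practical consequence of the `CustomStreamWrapper` change above: content-filter stops in streaming responses now surface as `litellm.ContentPolicyViolationError` (with `llm_provider` and `model` attached) instead of an Azure-specific error, so callers can handle them uniformly across providers. A minimal handling sketch, assuming an OpenAI key is configured, mirroring the new `test_content_policy_exception_openai` test:

```python
# Minimal sketch - assumes OPENAI_API_KEY is set in the environment.
import asyncio

import litellm


async def main():
    try:
        stream = await litellm.acompletion(
            model="gpt-3.5-turbo",
            stream=True,
            messages=[{"role": "user", "content": "Gimme the lyrics to Don't Stop Me Now"}],
        )
        async for chunk in stream:
            print(chunk)
    except litellm.ContentPolicyViolationError as e:
        # llm_provider and model are populated by CustomStreamWrapper in the diff above
        print("content policy violation from", e.llm_provider, "-", e.model)


asyncio.run(main())
```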