From 1c9a8c0b68be990ef3b61094344e48249d28ca47 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 16:21:57 +0530 Subject: [PATCH 01/20] feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. --- docs/my-website/sidebars.js | 40 ++++---- litellm/litellm_core_utils/litellm_logging.py | 6 +- litellm/proxy/_new_secret_config.yaml | 7 +- litellm/proxy/_types.py | 24 ++++- .../anthropic_passthrough_logging_handler.py | 39 ++++---- .../vertex_passthrough_logging_handler.py | 54 ++++++----- .../streaming_handler.py | 66 ++++++++++--- .../pass_through_endpoints/success_handler.py | 97 +++++++++++-------- litellm/proxy/utils.py | 15 ++- 9 files changed, 214 insertions(+), 134 deletions(-) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index f01402299..a18122aa6 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -199,6 +199,29 @@ const sidebars = { ], }, + { + type: "category", + label: "Guides", + items: [ + "exception_mapping", + "completion/provider_specific_params", + "completion/audio", + "completion/vision", + "completion/json_mode", + "completion/prompt_caching", + "completion/predict_outputs", + "completion/prefix", + "completion/drop_params", + "completion/prompt_formatting", + "completion/stream", + "completion/message_trimming", + "completion/function_call", + "completion/model_alias", + "completion/batching", + "completion/mock_requests", + "completion/reliable_completions", + ] + }, { type: "category", label: "Supported Endpoints", @@ -214,25 +237,8 @@ const sidebars = { }, items: [ "completion/input", - "completion/provider_specific_params", - "completion/json_mode", - "completion/prompt_caching", - "completion/audio", - "completion/vision", - "completion/predict_outputs", - "completion/prefix", - "completion/drop_params", - "completion/prompt_formatting", "completion/output", "completion/usage", - "exception_mapping", - "completion/stream", - 
"completion/message_trimming", - "completion/function_call", - "completion/model_alias", - "completion/batching", - "completion/mock_requests", - "completion/reliable_completions", ], }, "embedding/supported_embedding", diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 69d6adca4..a9ded458d 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1368,8 +1368,11 @@ class Logging: and customLogger is not None ): # custom logger functions print_verbose( - "success callbacks: Running Custom Callback Function" + "success callbacks: Running Custom Callback Function - {}".format( + callback + ) ) + customLogger.log_event( kwargs=self.model_call_details, response_obj=result, @@ -2359,6 +2362,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_mlflow_logger) return _mlflow_logger # type: ignore + def get_custom_logger_compatible_class( logging_integration: litellm._custom_logger_compatible_callbacks_literal, ) -> Optional[CustomLogger]: diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 974b091cf..2c25b61db 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -12,8 +12,5 @@ model_list: vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" -router_settings: - model_group_alias: - "gpt-4-turbo": # Aliased model name - model: "gpt-4" # Actual model name in 'model_list' - hidden: true \ No newline at end of file +litellm_settings: + success_callback: ["langfuse"] \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8b8dbf2e5..b0d272f26 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -12,7 +12,15 @@ from typing_extensions import Annotated, TypedDict from litellm.types.integrations.slack_alerting import AlertType from litellm.types.router import 
RouterErrors, UpdateRouterConfig -from litellm.types.utils import ProviderField, StandardCallbackDynamicParams +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ProviderField, + StandardCallbackDynamicParams, + StandardPassThroughResponseObject, + TextCompletionResponse, +) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -2133,3 +2141,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum): spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used.""" team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None.""" duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None.""" + + +PassThroughEndpointLoggingResultValues = Union[ + ModelResponse, + TextCompletionResponse, + ImageResponse, + EmbeddingResponse, + StandardPassThroughResponseObject, +] + + +class PassThroughEndpointLoggingTypedDict(TypedDict): + result: Optional[PassThroughEndpointLoggingResultValues] + kwargs: dict diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 35cff0db3..27e7848c0 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import ( ModelResponseIterator as AnthropicModelResponseIterator, ) from litellm.llms.anthropic.chat.transformation import AnthropicConfig +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -26,7 +27,7 @@ else: class AnthropicPassthroughLoggingHandler: @staticmethod - 
async def anthropic_passthrough_handler( + def anthropic_passthrough_handler( httpx_response: httpx.Response, response_body: dict, logging_obj: LiteLLMLoggingObj, @@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled """ @@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler: logging_obj=logging_obj, ) - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) - - pass + return { + "result": litellm_model_response, + "kwargs": kwargs, + } @staticmethod def _create_anthropic_response_logging_payload( @@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler: return kwargs @staticmethod - async def _handle_logging_anthropic_collected_chunks( + def _handle_logging_anthropic_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks @@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..." 
) - return + return { + "result": None, + "kwargs": {}, + } kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( litellm_model_response=complete_streaming_response, model=model, @@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler: end_time=end_time, logging_obj=litellm_logging_obj, ) - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index fe61f32ee..da1cf1d2a 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexModelResponseIterator, ) +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -25,7 +26,7 @@ else: class VertexPassthroughLoggingHandler: @staticmethod - async def vertex_passthrough_handler( + def vertex_passthrough_handler( httpx_response: httpx.Response, logging_obj: LiteLLMLoggingObj, url_route: str, @@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: if "generateContent" in url_route: model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) @@ -59,13 +60,11 @@ class 
VertexPassthroughLoggingHandler: logging_obj.model = litellm_model_response.model or model logging_obj.model_call_details["model"] = logging_obj.model - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + elif "predict" in url_route: from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( VertexImageGeneration, @@ -106,16 +105,18 @@ class VertexPassthroughLoggingHandler: logging_obj.model = model logging_obj.model_call_details["model"] = logging_obj.model - await logging_obj.async_success_handler( - result=litellm_prediction_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_prediction_response, + "kwargs": kwargs, + } + else: + return { + "result": None, + "kwargs": kwargs, + } @staticmethod - async def _handle_logging_vertex_collected_chunks( + def _handle_logging_vertex_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -124,7 +125,7 @@ class VertexPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks @@ -146,14 +147,15 @@ class VertexPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." 
) - return - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + return { + "result": None, + "kwargs": kwargs, + } + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 9ba5adfec..9cbc08955 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -1,5 +1,6 @@ import asyncio import json +import threading from datetime import datetime from enum import Enum from typing import AsyncIterable, Dict, List, Optional, Union @@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexAIIterator, ) -from litellm.types.utils import GenericStreamingChunk +from litellm.proxy._types import PassThroughEndpointLoggingResultValues +from litellm.types.utils import ( + GenericStreamingChunk, + ModelResponse, + StandardPassThroughResponseObject, +) from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, @@ -92,8 +98,12 @@ async def _route_streaming_logging_to_handler( - Anthropic - Vertex AI """ + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + kwargs: dict = {} if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, 
url_route=url_route, @@ -103,17 +113,45 @@ async def _route_streaming_logging_to_handler( all_chunks=all_chunks, end_time=end_time, ) + standard_logging_response_object = anthropic_passthrough_logging_handler_result[ + "result" + ] + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + standard_logging_response_object = vertex_passthrough_logging_handler_result[ + "result" + ] + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=f"cannot parse chunks to standard response object. 
Chunks={all_chunks}" + ) + threading.Thread( + target=litellm_logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + False, + ), + ).start() + await litellm_logging_obj.async_success_handler( + result=standard_logging_response_object, + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index e22a37052..c9c7707f0 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) +from litellm.proxy._types import PassThroughEndpointLoggingResultValues from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.utils import StandardPassThroughResponseObject @@ -49,53 +50,69 @@ class PassThroughEndpointLogging: cache_hit: bool, **kwargs, ): + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None if self.is_vertex_route(url_route): - await VertexPassthroughLoggingHandler.vertex_passthrough_handler( - httpx_response=httpx_response, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler.vertex_passthrough_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] elif 
self.is_anthropic_route(url_route): - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( - httpx_response=httpx_response, - response_body=response_body or {}, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, + anthropic_passthrough_logging_handler_result = ( + AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) - else: + + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + if standard_logging_response_object is None: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) - threading.Thread( - target=logging_obj.success_handler, - args=( - standard_logging_response_object, - start_time, - end_time, - cache_hit, - ), - ).start() - await logging_obj.async_success_handler( - result=( - json.dumps(result) - if isinstance(result, dict) - else standard_logging_response_object - ), - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + threading.Thread( + target=logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + cache_hit, + ), + ).start() + await logging_obj.async_success_handler( + result=( + json.dumps(result) + if isinstance(result, dict) + else standard_logging_response_object + ), + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) def is_vertex_route(self, url_route: str): for route in self.TRACKED_VERTEX_ROUTES: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 74bf398e7..0f7d6c3e0 100644 --- a/litellm/proxy/utils.py +++ 
b/litellm/proxy/utils.py @@ -337,14 +337,14 @@ class ProxyLogging: alert_to_webhook_url=self.alert_to_webhook_url, ) - if ( - self.alerting is not None - and "slack" in self.alerting - and "daily_reports" in self.alert_types - ): + if self.alerting is not None and "slack" in self.alerting: # NOTE: ENSURE we only add callbacks when alerting is on # We should NOT add callbacks when alerting is off - litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + if "daily_reports" in self.alert_types: + litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + litellm.success_callback.append( + self.slack_alerting_instance.response_taking_too_long_callback + ) if redis_cache is not None: self.internal_usage_cache.dual_cache.redis_cache = redis_cache @@ -354,9 +354,6 @@ class ProxyLogging: litellm.callbacks.append(self.max_budget_limiter) # type: ignore litellm.callbacks.append(self.cache_control_check) # type: ignore litellm.callbacks.append(self.service_logging_obj) # type: ignore - litellm.success_callback.append( - self.slack_alerting_instance.response_taking_too_long_callback - ) for callback in litellm.callbacks: if isinstance(callback, str): callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore From 5a698c678a7803a968cb74ca3f93368c8423c2e7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 16:41:58 +0530 Subject: [PATCH 02/20] fix(utils.py): allow disabling end user cost tracking with new param Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small --- litellm/__init__.py | 1 + litellm/integrations/prometheus.py | 10 ++++------ litellm/proxy/_new_secret_config.yaml | 4 +++- litellm/proxy/proxy_server.py | 4 ++-- litellm/utils.py | 10 ++++++++++ tests/local_testing/test_utils.py | 20 ++++++++++++++++++++ 6 files changed, 40 insertions(+), 9 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index c978b24ee..e6dc61dc7 
100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -280,6 +280,7 @@ default_max_internal_user_budget: Optional[float] = None max_internal_user_budget: Optional[float] = None internal_user_budget_duration: Optional[str] = None max_end_user_budget: Optional[float] = None +disable_end_user_cost_tracking: Optional[bool] = None #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None #### RELIABILITY #### diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bb28719a3..1460a1d7f 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.integrations.prometheus import * from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking class PrometheusLogger(CustomLogger): @@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger): model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] @@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger): # unpack kwargs model = kwargs.get("model", "") - litellm_params = kwargs.get("litellm_params", {}) or {} standard_logging_payload: StandardLoggingPayload = kwargs.get( "standard_logging_object", {} ) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - - end_user_id = proxy_server_request.get("body", 
{}).get("user", None) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 2c25b61db..f12226736 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -13,4 +13,6 @@ model_list: vertex_ai_location: "us-east5" litellm_settings: - success_callback: ["langfuse"] \ No newline at end of file + success_callback: ["langfuse"] + callbacks: ["prometheus"] + # disable_end_user_cost_tracking: true diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9d7c120a7..70bf5b523 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import ( from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import RouterGeneralSettings from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking try: from litellm._version import version @@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback( ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) user_id = metadata.get("user_api_key_user_id", None) team_id = metadata.get("user_api_key_team_id", None) diff --git a/litellm/utils.py b/litellm/utils.py index 003971142..262af3418 100644 --- 
a/litellm/utils.py +++ b/litellm/utils.py @@ -6170,3 +6170,13 @@ class ProviderConfigManager: return litellm.GroqChatConfig() return OpenAIGPTConfig() + + +def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]: + """ + Used for enforcing `disable_end_user_cost_tracking` param. + """ + proxy_server_request = litellm_params.get("proxy_server_request") or {} + if litellm.disable_end_user_cost_tracking: + return None + return proxy_server_request.get("body", {}).get("user", None) diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 52946ca30..cf1db27e8 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1012,3 +1012,23 @@ def test_models_by_provider(): for provider in providers: assert provider in models_by_provider.keys() + + +@pytest.mark.parametrize( + "litellm_params, disable_end_user_cost_tracking, expected_end_user_id", + [ + ({}, False, None), + ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), + ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ], +) +def test_get_end_user_id_for_cost_tracking( + litellm_params, disable_end_user_cost_tracking, expected_end_user_id +): + from litellm.utils import get_end_user_id_for_cost_tracking + + litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking + assert ( + get_end_user_id_for_cost_tracking(litellm_params=litellm_params) + == expected_end_user_id + ) From 97d8aa0b3aaaf73c70a61e6ff0930ef61cbd43f6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 16:43:35 +0530 Subject: [PATCH 03/20] docs(configs.md): add disable_end_user_cost_tracking reference to docs --- docs/my-website/docs/proxy/configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 3b6b336d6..6f54df2ae 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -754,6 
+754,7 @@ general_settings: | cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. | | cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) | | cache_params.mode | string | The mode of the cache. [Further docs](./caching) | +| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. | ### general_settings - Reference From 1014216d734d5dfcfb929d69bc2c9e8c49813040 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 22:59:01 +0530 Subject: [PATCH 04/20] feat(key_management_endpoints.py): add support for restricting access to `/key/generate` by team/proxy level role Enables admin to restrict key creation, and assign team admins to handle distributing keys --- litellm/__init__.py | 2 + litellm/proxy/_new_secret_config.yaml | 6 +- litellm/proxy/_types.py | 4 - .../key_management_endpoints.py | 73 +++++++++++++++++++ litellm/types/utils.py | 13 ++++ 5 files changed, 93 insertions(+), 5 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index e6dc61dc7..65b1b3465 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -24,6 +24,7 @@ from litellm.proxy._types import ( KeyManagementSettings, LiteLLM_UpperboundKeyGenerateParams, ) +from litellm.types.utils import StandardKeyGenerationConfig import httpx import dotenv from enum import Enum @@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None generic_logger_headers: Optional[Dict] = None default_key_generate_params: Optional[Dict] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None +key_generation_settings: Optional[StandardKeyGenerationConfig] = None default_internal_user_params: Optional[Dict] = None default_team_settings: Optional[List] = None max_user_budget: Optional[float] = None diff --git a/litellm/proxy/_new_secret_config.yaml 
b/litellm/proxy/_new_secret_config.yaml index f12226736..dd4c06576 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -15,4 +15,8 @@ model_list: litellm_settings: success_callback: ["langfuse"] callbacks: ["prometheus"] - # disable_end_user_cost_tracking: true + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b0d272f26..9e05e4cff 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -892,10 +892,6 @@ class DeleteCustomerRequest(LiteLLMBase): class Member(LiteLLMBase): role: Literal[ - LitellmUserRoles.ORG_ADMIN, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - # older Member roles "admin", "user", ] diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index e4493a28c..ab13616d5 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -40,6 +40,77 @@ from litellm.proxy.utils import ( ) from litellm.secret_managers.main import get_secret + +def _is_team_key(data: GenerateKeyRequest): + return data.team_id is not None + + +def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("team_key_generation") is None + ): + return True + + if user_api_key_dict.team_member is None: + raise HTTPException( + status_code=400, + detail=f"User not assigned to team. 
Got team_member={user_api_key_dict.team_member}", + ) + + team_member_role = user_api_key_dict.team_member.role + if ( + team_member_role + not in litellm.key_generation_settings["team_key_generation"][ # type: ignore + "allowed_team_member_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore + ) + return True + + +def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("personal_key_generation") is None + ): + return True + + if ( + user_api_key_dict.user_role + not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore + "allowed_user_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. 
Your role={user_api_key_dict.user_role}", # type: ignore + ) + return True + + +def key_generation_check( + user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest +) -> bool: + """ + Check if admin has restricted key creation to certain roles for teams or individuals + """ + if litellm.key_generation_settings is None: + return True + + ## check if key is for team or individual + is_team_key = _is_team_key(data=data) + + if is_team_key: + return _team_key_generation_check(user_api_key_dict) + else: + return _personal_key_generation_check(user_api_key_dict=user_api_key_dict) + + router = APIRouter() @@ -131,6 +202,8 @@ async def generate_key_fn( # noqa: PLR0915 raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=message ) + elif litellm.key_generation_settings is not None: + key_generation_check(user_api_key_dict=user_api_key_dict, data=data) # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/litellm/types/utils.py b/litellm/types/utils.py index d02129681..334894320 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False): langsmith_api_key: Optional[str] langsmith_project: Optional[str] langsmith_base_url: Optional[str] + + +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[str] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig From 463fa0c9d51baa1758e6bb9655f3c4fa26bf7a90 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 23:03:38 +0530 Subject: [PATCH 05/20] test(test_key_management.py): add unit testing for personal / team key restriction checks --- .../test_key_management.py | 62 +++++++++++++++++++ 1 file 
changed, 62 insertions(+) diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index b039a101b..81d9fb676 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -542,3 +542,65 @@ async def test_list_teams(prisma_client): # Clean up await prisma_client.delete_data(team_id_list=[team_id], table_name="team") + + +def test_is_team_key(): + from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key + + assert _is_team_key(GenerateKeyRequest(team_id="test_team_id")) + assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id")) + + +def test_team_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _team_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "team_key_generation": {"allowed_team_member_roles": ["admin"]} + } + + assert _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + team_member=Member(role="admin", user_id="test_user_id"), + ) + ) + + with pytest.raises(HTTPException): + _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="test_user_id", + team_member=Member(role="user", user_id="test_user_id"), + ) + ) + + +def test_personal_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _personal_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]} + } + + assert _personal_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" + ) + ) + + with pytest.raises(HTTPException): + _personal_key_generation_check( + UserAPIKeyAuth( + 
user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="admin", + ) + ) From eb0a357eda72812b9dc03274fca15d78240b483b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Nov 2024 23:11:58 +0530 Subject: [PATCH 06/20] docs: add docs on restricting key creation --- docs/my-website/docs/proxy/configs.md | 1 + docs/my-website/docs/proxy/self_serve.md | 8 ++- docs/my-website/docs/proxy/virtual_keys.md | 69 ++++++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 6f54df2ae..df22a29e3 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -755,6 +755,7 @@ general_settings: | cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) | | cache_params.mode | string | The mode of the cache. [Further docs](./caching) | | disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. | +| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) | ### general_settings - Reference diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index e04aa4b44..494d9e60d 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -217,4 +217,10 @@ litellm_settings: max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None. rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None. -``` \ No newline at end of file + + key_generation_settings: # Restricts who can generate keys. 
[Further docs](./virtual_keys.md#restricting-key-generation) + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 3b9a2a03e..98b06d33b 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -811,6 +811,75 @@ litellm_settings: team_id: "core-infra" ``` +### Restricting Key Generation + +Use this to control who can generate keys. Useful when letting others create keys on the UI. + +```yaml +litellm_settings: + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` + +#### Spec + +```python +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[LitellmUserRoles] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig + + +class LitellmUserRoles(str, enum.Enum): + """ + Admin Roles: + PROXY_ADMIN: admin over the platform + PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend + ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization + + Internal User Roles: + INTERNAL_USER: can login, view/create/delete their own keys, view their spend + INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend + + + Team Roles: + TEAM: used for JWT auth + + + Customer Roles: + CUSTOMER: External users -> these are customers + + """ + + # Admin Roles + PROXY_ADMIN = "proxy_admin" + PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer" + + # Organization admins + ORG_ADMIN = "org_admin" + + # Internal 
User Roles + INTERNAL_USER = "internal_user" + INTERNAL_USER_VIEW_ONLY = "internal_user_viewer" + + # Team Roles + TEAM = "team" + + # Customer Roles - External users of proxy + CUSTOMER = "customer" +``` + + ## **Next Steps - Set Budgets, Rate Limits per Virtual Key** [Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users) From 4beb48829cb9941e49ac2c52b93c576a7fc82152 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 00:08:03 +0530 Subject: [PATCH 07/20] docs(finetuned_models.md): add new guide on calling finetuned models --- .../docs/guides/finetuned_models.md | 74 +++++++++++++++++++ docs/my-website/sidebars.js | 2 + 2 files changed, 76 insertions(+) create mode 100644 docs/my-website/docs/guides/finetuned_models.md diff --git a/docs/my-website/docs/guides/finetuned_models.md b/docs/my-website/docs/guides/finetuned_models.md new file mode 100644 index 000000000..cb0d49b44 --- /dev/null +++ b/docs/my-website/docs/guides/finetuned_models.md @@ -0,0 +1,74 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + +# Calling Finetuned Models + +## OpenAI + + +| Model Name | Function Call | +|---------------------------|-----------------------------------------------------------------| +| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` | +| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` | +| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` | + + +## Vertex AI + +Fine tuned models on vertex have a numerical model/endpoint id. 
+ + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/", # e.g. vertex_ai/4965075652664360960 + messages=[{ "content": "Hello, how are you?","role": "user"}], + base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing +) +``` + + + + +1. Add Vertex Credentials to your env + +```bash +!gcloud auth application-default login +``` + +2. Setup config.yaml + +```yaml +- model_name: finetuned-gemini + litellm_params: + model: vertex_ai/ + vertex_project: + vertex_location: + model_info: + base_model: vertex_ai/gemini-1.5-pro # IMPORTANT +``` + +3. Test it! + +```bash +curl --location 'https://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: ' \ +--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}' +``` + + + + + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index a18122aa6..f2bb1c5e9 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -205,6 +205,7 @@ const sidebars = { items: [ "exception_mapping", "completion/provider_specific_params", + "guides/finetuned_models", "completion/audio", "completion/vision", "completion/json_mode", @@ -220,6 +221,7 @@ const sidebars = { "completion/batching", "completion/mock_requests", "completion/reliable_completions", + ] }, { From d788c3c37fd646e8ffd563d520eb67e8b4eca428 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 00:21:04 +0530 Subject: [PATCH 08/20] docs(input.md): cleanup anthropic supported params Closes https://github.com/BerriAI/litellm/issues/6856 --- docs/my-website/docs/completion/input.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md 
index c563a5bf0..e55c160e0 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea | Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers | |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| -|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ | |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ | |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ | |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | From 1a3fb18a6499701dd3793fc55962961be2310e53 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 00:32:40 +0530 Subject: [PATCH 09/20] test(test_embedding.py): add test for passing extra headers via embedding --- tests/local_testing/test_embedding.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index d7988e690..97707a234 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1080,3 +1080,20 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.image_tokens > 0 else: assert response.usage.prompt_tokens_details.text_tokens > 0 + + +def test_embedding_with_extra_headers(): + input = ["hello world"] + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + + with patch.object(client, "post") as mock_post: + embedding( + model="cohere/embed-english-v3.0", 
+ input=input, + extra_headers={"my-test-param": "hello-world"}, + client=client, + ) + + assert "my-test-param" in mock_post.call_args.kwargs["headers"] From 94fe1355243943f96043dd510afc7498dc0707d0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 00:47:26 +0530 Subject: [PATCH 10/20] feat(cohere/embed): pass client to async embedding --- litellm/llms/cohere/embed/handler.py | 6 ++++++ tests/local_testing/test_embedding.py | 31 +++++++++++++++++++-------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 5b224c375..afeba10b5 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -74,6 +74,7 @@ async def async_embedding( }, ) ## COMPLETION CALL + if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.COHERE, @@ -151,6 +152,11 @@ def embedding( api_key=api_key, headers=headers, encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), ) ## LOGGING diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index 97707a234..23d712b00 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1082,18 +1082,31 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.text_tokens > 0 -def test_embedding_with_extra_headers(): +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_embedding_with_extra_headers(sync_mode): input = ["hello world"] - from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler - client = HTTPHandler() + if sync_mode: + client = HTTPHandler() + else: + client = AsyncHTTPHandler() + data = { + "model": "cohere/embed-english-v3.0", + "input": input, + "extra_headers": 
{"my-test-param": "hello-world"}, + "client": client, + } with patch.object(client, "post") as mock_post: - embedding( - model="cohere/embed-english-v3.0", - input=input, - extra_headers={"my-test-param": "hello-world"}, - client=client, - ) + try: + if sync_mode: + embedding(**data) + else: + await litellm.aembedding(**data) + except Exception as e: + print(e) + mock_post.assert_called_once() assert "my-test-param" in mock_post.call_args.kwargs["headers"] From 250d66b3352e5a950b5a33c53d6d2cccdafc98e0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 01:07:39 +0530 Subject: [PATCH 11/20] feat(rerank.py): add `/v1/rerank` if missing for cohere base url Closes https://github.com/BerriAI/litellm/issues/6844 --- litellm/llms/cohere/rerank.py | 37 ++++++++++++++++++++++++++---- litellm/rerank_api/main.py | 4 +++- tests/local_testing/test_rerank.py | 29 +++++++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index 022ffc6f9..8de2dfbb4 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs from typing import Any, Dict, List, Optional, Union +import httpx + import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, _get_httpx_client, get_async_httpx_client, ) @@ -34,6 +38,23 @@ class CohereRerank(BaseLLM): # Merge other headers, overriding any default ones except Authorization return {**default_headers, **headers} + def ensure_rerank_endpoint(self, api_base: str) -> str: + """ + Ensures the `/v1/rerank` endpoint is appended to the given `api_base`. + If `/v1/rerank` is already present, the original URL is returned. + + :param api_base: The base API URL. 
+ :return: A URL with `/v1/rerank` appended if missing. + """ + # Parse the base URL to ensure proper structure + url = httpx.URL(api_base) + + # Check if the URL already ends with `/v1/rerank` + if not url.path.endswith("/v1/rerank"): + url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank") + + return str(url) + def rerank( self, model: str, @@ -48,9 +69,10 @@ class CohereRerank(BaseLLM): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, # New parameter + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: headers = self.validate_environment(api_key=api_key, headers=headers) - + api_base = self.ensure_rerank_endpoint(api_base) request_data = RerankRequest( model=model, query=query, @@ -76,9 +98,13 @@ class CohereRerank(BaseLLM): if _is_async: return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method - client = _get_httpx_client() + if client is not None and isinstance(client, HTTPHandler): + client = client + else: + client = _get_httpx_client() + response = client.post( - api_base, + url=api_base, headers=headers, json=request_data_dict, ) @@ -100,10 +126,13 @@ class CohereRerank(BaseLLM): api_key: str, api_base: str, headers: dict, + client: Optional[AsyncHTTPHandler] = None, ) -> RerankResponse: request_data_dict = request_data.dict(exclude_none=True) - client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE) + client = client or get_async_httpx_client( + llm_provider=litellm.LlmProviders.COHERE + ) response = await client.post( api_base, diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 9cc8a8c1d..7e6dc7503 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -91,6 +91,7 @@ def rerank( model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) user = kwargs.get("user", None) + client = 
kwargs.get("client", None) try: _is_async = kwargs.pop("arerank", False) is True optional_params = GenericLiteLLMParams(**kwargs) @@ -150,7 +151,7 @@ def rerank( or optional_params.api_base or litellm.api_base or get_secret("COHERE_API_BASE") # type: ignore - or "https://api.cohere.com/v1/rerank" + or "https://api.cohere.com" ) if api_base is None: @@ -173,6 +174,7 @@ def rerank( _is_async=_is_async, headers=headers, litellm_logging_obj=litellm_logging_obj, + client=client, ) elif _custom_llm_provider == "azure_ai": api_base = ( diff --git a/tests/local_testing/test_rerank.py b/tests/local_testing/test_rerank.py index c5ed1efe5..f9160991c 100644 --- a/tests/local_testing/test_rerank.py +++ b/tests/local_testing/test_rerank.py @@ -258,3 +258,32 @@ async def test_rerank_custom_callbacks(): assert custom_logger.kwargs.get("response_cost") > 0.0 assert custom_logger.response_obj is not None assert custom_logger.response_obj.results is not None + + +def test_complete_base_url_cohere(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + litellm.api_base = "http://localhost:4000" + litellm.set_verbose = True + + text = "Hello there!" 
+ list_texts = ["Hello there!", "How are you?", "How do you do?"] + + rerank_model = "rerank-multilingual-v3.0" + + with patch.object(client, "post") as mock_post: + try: + litellm.rerank( + model=rerank_model, + query=text, + documents=list_texts, + custom_llm_provider="cohere", + client=client, + ) + except Exception as e: + print(e) + + print("mock_post.call_args", mock_post.call_args) + mock_post.assert_called_once() + assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"] From dfb34dfe921bd4e259b85d6099bee4bf39c6ff6b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 01:23:03 +0530 Subject: [PATCH 12/20] fix(main.py): pass extra_headers param to openai Fixes https://github.com/BerriAI/litellm/issues/6836 --- litellm/main.py | 4 ++++ tests/local_testing/test_embedding.py | 1 + 2 files changed, 5 insertions(+) diff --git a/litellm/main.py b/litellm/main.py index 5d433eb36..5095ce518 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3440,6 +3440,10 @@ def embedding( # noqa: PLR0915 or litellm.openai_key or get_secret_str("OPENAI_API_KEY") ) + + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + api_type = "openai" api_version = None diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index 23d712b00..096dfc419 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1085,6 +1085,7 @@ def test_cohere_img_embeddings(input, input_type): @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_embedding_with_extra_headers(sync_mode): + input = ["hello world"] from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler From 541326731fb930889b0afed2d7541d9a7efe7e46 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 02:00:45 +0530 Subject: [PATCH 13/20] fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set 
Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically --- litellm/litellm_core_utils/litellm_logging.py | 82 ++++++------------- .../test_unit_tests_init_callbacks.py | 72 ++++++++++++++++ 2 files changed, 96 insertions(+), 58 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index a9ded458d..298e28974 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -934,19 +934,10 @@ class Logging: status="success", ) ) - if self.dynamic_success_callbacks is not None and isinstance( - self.dynamic_success_callbacks, list - ): - callbacks = self.dynamic_success_callbacks - ## keep the internal functions ## - for callback in litellm.success_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) ## REDACT MESSAGES ## result = redact_message_input_output_from_logging( @@ -1469,21 +1460,10 @@ class Logging: status="success", ) ) - if self.dynamic_async_success_callbacks is not None and isinstance( - self.dynamic_async_success_callbacks, list - ): - callbacks = self.dynamic_async_success_callbacks - ## keep the internal functions ## - for callback in litellm._async_success_callback: - callback_name = "" - if isinstance(callback, CustomLogger): - callback_name = callback.__class__.__name__ - if callable(callback): - callback_name = callback.__name__ - if "_PROXY_" in callback_name: - callbacks.append(callback) - else: - callbacks = litellm._async_success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_success_callbacks, + global_callbacks=litellm._async_success_callback, + ) result = 
redact_message_input_output_from_logging( model_call_details=( @@ -1750,21 +1730,10 @@ class Logging: start_time=start_time, end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_failure_callbacks is not None and isinstance( - self.dynamic_failure_callbacks, list - ): - callbacks = self.dynamic_failure_callbacks - ## keep the internal functions ## - for callback in litellm.failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_failure_callbacks, + global_callbacks=litellm.failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created @@ -1947,21 +1916,10 @@ class Logging: end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_async_failure_callbacks is not None and isinstance( - self.dynamic_async_failure_callbacks, list - ): - callbacks = self.dynamic_async_failure_callbacks - ## keep the internal functions ## - for callback in litellm._async_failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm._async_failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_failure_callbacks, + global_callbacks=litellm._async_failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created for callback in callbacks: @@ -2953,3 +2911,11 @@ def modify_integration(integration_name, integration_params): if integration_name == "supabase": if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] + + +def get_combined_callback_list( + dynamic_success_callbacks: Optional[List], 
global_callbacks: List +) -> List: + if dynamic_success_callbacks is None: + return global_callbacks + return list(set(dynamic_success_callbacks + global_callbacks)) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 38883fa38..2c373772a 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -216,3 +216,75 @@ async def test_init_custom_logger_compatible_class_as_callback(): await use_callback_in_llm_call(callback, used_in="success_callback") reset_env_vars() + + +def test_dynamic_logging_global_callback(): + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.integrations.custom_logger import CustomLogger + from litellm.types.utils import ModelResponse, Choices, Message, Usage + + cl = CustomLogger() + + litellm_logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + kwargs={ + "langfuse_public_key": "my-mock-public-key", + "langfuse_secret_key": "my-mock-secret-key", + }, + dynamic_success_callbacks=["langfuse"], + ) + + with patch.object(cl, "log_success_event") as mock_log_success_event: + cl.log_success_event = mock_log_success_event + litellm.success_callback = [cl] + + try: + litellm_logging.success_handler( + result=ModelResponse( + id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", + created=1732306261, + model="claude-3-opus-20240229", + object="chat.completion", + system_fingerprint=None, + choices=[ + Choices( + finish_reason="stop", + index=0, + message=Message( + content="hello", + role="assistant", + tool_calls=None, + function_call=None, + ), + ) + ], + usage=Usage( + completion_tokens=20, + prompt_tokens=10, + total_tokens=30, + 
completion_tokens_details=None, + prompt_tokens_details=None, + ), + ), + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + ) + except Exception as e: + print(f"Error: {e}") + + mock_log_success_event.assert_called_once() + + +def test_get_combined_callback_list(): + from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list + + assert get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) == ["langfuse", "lago"] From 31943bf2ad1926d528d0581bf6e2431ee42fc9de Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 02:08:50 +0530 Subject: [PATCH 14/20] fix(handler.py): fix linting error --- litellm/llms/azure_ai/rerank/handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py index a67c893f2..60edfd296 100644 --- a/litellm/llms/azure_ai/rerank/handler.py +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -4,6 +4,7 @@ import httpx from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.cohere.rerank import CohereRerank +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.rerank import RerankResponse @@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: if headers is None: From f8a46b595073b3816006bfffd16debc596a228e7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 02:36:09 +0530 Subject: [PATCH 15/20] fix: fix typing --- litellm/proxy/_types.py | 44 ++++++++++++++++--- .../test_role_based_access.py | 10 ++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9e05e4cff..74e82b0ea 100644 --- 
a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2,6 +2,7 @@ import enum import json import os import sys +import traceback import uuid from dataclasses import fields from datetime import datetime @@ -890,11 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase): user_ids: List[str] -class Member(LiteLLMBase): - role: Literal[ - "admin", - "user", - ] +class MemberBase(LiteLLMBase): user_id: Optional[str] = None user_email: Optional[str] = None @@ -908,6 +905,21 @@ class Member(LiteLLMBase): return values +class Member(MemberBase): + role: Literal[ + "admin", + "user", + ] + + +class OrgMember(MemberBase): + role: Literal[ + LitellmUserRoles.ORG_ADMIN, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] + + class TeamBase(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None @@ -1970,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase): # Replace member_data with the single Member object data["member"] = member # Call the superclass __init__ method to initialize the object + traceback.print_stack() + super().__init__(**data) + + +class OrgMemberAddRequest(LiteLLMBase): + member: Union[List[OrgMember], OrgMember] + + def __init__(self, **data): + member_data = data.get("member") + if isinstance(member_data, list): + # If member is a list of dictionaries, convert each dictionary to a Member object + members = [OrgMember(**item) for item in member_data] + # Replace member_data with the list of Member objects + data["member"] = members + elif isinstance(member_data, dict): + # If member is a dictionary, convert it to a single Member object + member = OrgMember(**member_data) + # Replace member_data with the single Member object + data["member"] = member + # Call the superclass __init__ method to initialize the object super().__init__(**data) @@ -2021,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse): # Organization Member Requests -class OrganizationMemberAddRequest(MemberAddRequest): +class 
OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str max_budget_in_organization: Optional[float] = ( None # Users max budget within the organization diff --git a/tests/proxy_admin_ui_tests/test_role_based_access.py b/tests/proxy_admin_ui_tests/test_role_based_access.py index 609a3598d..ff73143bf 100644 --- a/tests/proxy_admin_ui_tests/test_role_based_access.py +++ b/tests/proxy_admin_ui_tests/test_role_based_access.py @@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=user_role, user_id=created_user_id), + member=OrgMember(role=user_role, user_id=created_user_id), ), http_request=None, ) @@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member( + member=OrgMember( role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org ), ), @@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( 
organization_id=org1_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) From e2ccf6b7f452c8997ae9e418639943d2d1bfff0b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 02:44:10 +0530 Subject: [PATCH 16/20] build: add conftest to proxy_admin_ui_tests/ --- litellm/proxy/_new_secret_config.yaml | 5 --- tests/proxy_admin_ui_tests/conftest.py | 54 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 tests/proxy_admin_ui_tests/conftest.py diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index dd4c06576..47204d9c8 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -15,8 +15,3 @@ model_list: litellm_settings: success_callback: ["langfuse"] callbacks: ["prometheus"] - key_generation_settings: - team_key_generation: - allowed_team_member_roles: ["admin"] - personal_key_generation: # maps to 'Default Team' on UI - allowed_user_roles: ["proxy_admin"] \ No newline at end of file diff --git a/tests/proxy_admin_ui_tests/conftest.py b/tests/proxy_admin_ui_tests/conftest.py new file mode 100644 index 000000000..eca0bc431 --- /dev/null +++ b/tests/proxy_admin_ui_tests/conftest.py @@ -0,0 +1,54 @@ +# conftest.py + +import importlib +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(): + """ + This fixture reloads litellm before every test function, to speed up testing by removing chained callbacks.
+ """ + curr_dir = os.getcwd() # Get the current working directory + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path + + import litellm + from litellm import Router + + importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + print(litellm) + # from litellm import Router, completion, aembedding, acompletion, embedding + yield + + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + + +def pytest_collection_modifyitems(config, items): + # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] + + # Sort tests based on their names + custom_logger_tests.sort(key=lambda x: x.name) + other_tests.sort(key=lambda x: x.name) + + # Reorder the items list + items[:] = custom_logger_tests + other_tests From ed64dd7f9d43a4c54ad91f9450d46738fe3401fc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 03:06:51 +0530 Subject: [PATCH 17/20] test: fix test --- .../test_unit_tests_init_callbacks.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 2c373772a..15c2118d8 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -285,6 +285,9 @@ def test_dynamic_logging_global_callback(): def test_get_combined_callback_list(): from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list - assert get_combined_callback_list( + assert "langfuse" in get_combined_callback_list( 
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] - ) == ["langfuse", "lago"] + ) + assert "lago" in get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) From b5c5f87e25105ee8521ccc8f1c3eccfff1b0f14e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 03:12:59 +0530 Subject: [PATCH 18/20] fix: fix linting errors --- litellm/proxy/management_endpoints/organization_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index 81d135097..363384375 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -352,7 +352,7 @@ async def organization_member_add( }, ) - members: List[Member] + members: List[OrgMember] if isinstance(data.member, List): members = data.member else: @@ -397,7 +397,7 @@ async def organization_member_add( async def add_member_to_organization( - member: Member, + member: OrgMember, organization_id: str, prisma_client: PrismaClient, ) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]: From 88db948d29683dba0fc324fa4f7470f8967cc398 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 12:35:55 +0530 Subject: [PATCH 19/20] test: fix test --- tests/local_testing/test_rerank.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/local_testing/test_rerank.py b/tests/local_testing/test_rerank.py index f9160991c..5fca6f135 100644 --- a/tests/local_testing/test_rerank.py +++ b/tests/local_testing/test_rerank.py @@ -215,7 +215,10 @@ async def test_rerank_custom_api_base(): args_to_api = kwargs["json"] print("Arguments passed to API=", args_to_api) print("url = ", _url) - assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/" + assert ( + _url[0] + == 
"https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" + ) assert args_to_api == expected_payload assert response.id is not None assert response.results is not None From b06b0248ffb69c1121755369bf7551b06b1bf9cc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 14:11:54 +0530 Subject: [PATCH 20/20] fix: fix pass through testing --- .../vertex_passthrough_logging_handler.py | 5 ++++ .../test_unit_test_anthropic_pass_through.py | 27 ++----------------- .../test_unit_test_streaming.py | 1 + 3 files changed, 8 insertions(+), 25 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 75a0d04ec..2773979ad 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -153,6 +153,11 @@ class VertexPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." 
) + return { + "result": None, + "kwargs": kwargs, + } + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( litellm_model_response=complete_streaming_response, model=model, diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py index afb77f718..ecd289005 100644 --- a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py @@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler( start_time = datetime.now() end_time = datetime.now() - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( httpx_response=mock_httpx_response, response_body=mock_response, logging_obj=mock_logging_obj, @@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler( cache_hit=False, ) - # Assert that async_success_handler was called - assert mock_logging_obj.async_success_handler.called - - call_args = mock_logging_obj.async_success_handler.call_args - call_kwargs = call_args.kwargs - print("call_kwargs", call_kwargs) - - # Assert required fields are present in call_kwargs - assert "result" in call_kwargs - assert "start_time" in call_kwargs - assert "end_time" in call_kwargs - assert "cache_hit" in call_kwargs - assert "response_cost" in call_kwargs - assert "model" in call_kwargs - assert "standard_logging_object" in call_kwargs - - # Assert specific values and types - assert isinstance(call_kwargs["result"], litellm.ModelResponse) - assert isinstance(call_kwargs["start_time"], datetime) - assert isinstance(call_kwargs["end_time"], datetime) - assert isinstance(call_kwargs["cache_hit"], bool) - assert isinstance(call_kwargs["response_cost"], float) - assert call_kwargs["model"] == "claude-3-opus-20240229" - assert 
isinstance(call_kwargs["standard_logging_object"], dict) + assert isinstance(result["result"], litellm.ModelResponse) def test_create_anthropic_response_logging_payload(mock_logging_obj): diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py index bbbc465fc..61b71b56d 100644 --- a/tests/pass_through_unit_tests/test_unit_test_streaming.py +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): litellm_logging_obj = MagicMock() start_time = datetime.now() passthrough_success_handler_obj = MagicMock() + litellm_logging_obj.async_success_handler = AsyncMock() # Capture yielded chunks and perform detailed assertions received_chunks = []