From 372323c38a5ba849b51837f2eb80e89d74f2a696 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 May 2024 10:30:23 -0700 Subject: [PATCH 1/4] feat(proxy_server.py): allow admin to return rejected response as string to user Closes https://github.com/BerriAI/litellm/issues/3671 --- litellm/integrations/custom_logger.py | 15 +++- litellm/proxy/proxy_server.py | 35 ++++++++- litellm/proxy/utils.py | 109 +++++++++++++++++++++++++- litellm/utils.py | 26 ++++++ 4 files changed, 175 insertions(+), 10 deletions(-) diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d508825922..accb4f80f1 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -4,7 +4,7 @@ import dotenv, os from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache - +from litellm.utils import ModelResponse from typing import Literal, Union, Optional import traceback @@ -64,8 +64,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, - call_type: Literal["completion", "embeddings", "image_generation"], - ): + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Optional[ + Union[Exception, str, dict] + ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm pass async def async_post_call_failure_hook( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ef5380d089..6ad3e598a5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3761,11 +3761,24 @@ async def chat_completion( data["litellm_logging_obj"] = logging_obj - ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( + ### CALL HOOKS ### - modify/reject incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore user_api_key_dict=user_api_key_dict, data=data, call_type="completion" ) + if isinstance(data, litellm.ModelResponse): + return data + elif isinstance(data, litellm.CustomStreamWrapper): + selected_data_generator = select_data_generator( + response=data, + user_api_key_dict=user_api_key_dict, + request_data={}, + ) + + return StreamingResponse( + selected_data_generator, + media_type="text/event-stream", + ) tasks = [] tasks.append( proxy_logging_obj.during_call_hook( @@ -3998,10 +4011,24 @@ async def completion( data["model"] = litellm.model_alias_map[data["model"]] ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( - user_api_key_dict=user_api_key_dict, data=data, call_type="completion" + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" ) + if isinstance(data, litellm.TextCompletionResponse): + return data + elif isinstance(data, litellm.TextCompletionStreamWrapper): + selected_data_generator = select_data_generator( + response=data, + user_api_key_dict=user_api_key_dict, + request_data={}, + ) + + return StreamingResponse( + selected_data_generator, + media_type="text/event-stream", + ) + ### ROUTE THE REQUESTs ### router_model_names = llm_router.model_names if llm_router is not None else [] # skip router if user passed their key diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 09e772e10b..fc49ebc7fe 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -19,7 +19,16 @@ from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) from litellm._service_logger import ServiceLogging, ServiceTypes -from litellm import ModelResponse, EmbeddingResponse, ImageResponse +from litellm import ( + ModelResponse, + EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + TextCompletionResponse, + CustomStreamWrapper, + TextCompletionStreamWrapper, +) +from litellm.utils import ModelResponseIterator from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck @@ -32,6 +41,7 @@ from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from datetime import datetime, timedelta from litellm.integrations.slack_alerting import SlackAlerting +from typing_extensions import overload def print_verbose(print_statement): @@ -176,18 +186,60 @@ class ProxyLogging: ) litellm.utils.set_callbacks(callback_list=callback_list) + # fmt: off + + @overload + async def pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + call_type: Literal["completion"] + ) -> Union[dict, ModelResponse, CustomStreamWrapper]: + ... + + @overload + async def pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + call_type: Literal["text_completion"] + ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]: + ... + + @overload + async def pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + call_type: Literal["embeddings", + "image_generation", + "moderation", + "audio_transcription",] + ) -> dict: + ... + + # fmt: on + + # The actual implementation of the function async def pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal[ "completion", + "text_completion", "embeddings", "image_generation", "moderation", "audio_transcription", ], - ): + ) -> Union[ + dict, + ModelResponse, + TextCompletionResponse, + CustomStreamWrapper, + TextCompletionStreamWrapper, + ]: """ Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. @@ -214,7 +266,58 @@ class ProxyLogging: call_type=call_type, ) if response is not None: - data = response + if isinstance(response, Exception): + raise response + elif isinstance(response, dict): + data = response + elif isinstance(response, str): + if call_type == "completion": + _chat_response = ModelResponse() + _chat_response.choices[0].message.content = response + + if ( + data.get("stream", None) is not None + and data["stream"] == True + ): + _iterator = ModelResponseIterator( + model_response=_chat_response + ) + return CustomStreamWrapper( + completion_stream=_iterator, + model=data.get("model", ""), + custom_llm_provider="cached_response", + logging_obj=data.get( + "litellm_logging_obj", None + ), + ) + return _response + elif call_type == "text_completion": + if ( + data.get("stream", None) is not None + and data["stream"] == True + ): + _chat_response = ModelResponse() + _chat_response.choices[0].message.content = response + + if ( + data.get("stream", None) is not None + and data["stream"] == True + ): + _iterator = ModelResponseIterator( + model_response=_chat_response + ) + return TextCompletionStreamWrapper( + completion_stream=_iterator, + model=data.get("model", ""), + ) + else: + _response = TextCompletionResponse() + _response.choices[0].text = response + return _response + else: + raise HTTPException( + status_code=400, detail={"error": response} + ) print_verbose(f"final data being sent to {call_type} call: {data}") return data diff --git a/litellm/utils.py b/litellm/utils.py index ac246fca6a..1e0485755c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -12187,3 +12187,29 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str: return request_info except: return request_info + + +class ModelResponseIterator: + def __init__(self, model_response): + self.model_response = model_response + self.is_done = False + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + if self.is_done: + raise StopIteration + self.is_done = True + return self.model_response + + # Async iterator + def __aiter__(self): + return self + + async def __anext__(self): + if self.is_done: + raise StopAsyncIteration + self.is_done = True + return self.model_response From f11f207ae6ba3a39a886634b612b06f591b6eaca Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 May 2024 11:14:36 -0700 Subject: [PATCH 2/4] feat(proxy_server.py): refactor returning rejected message, to work with error logging log the rejected request as a failed call to langfuse/slack alerting --- litellm/exceptions.py | 26 +++++ litellm/integrations/custom_logger.py | 1 - litellm/proxy/_super_secret_config.yaml | 6 +- litellm/proxy/_types.py | 4 + .../proxy/hooks/prompt_injection_detection.py | 8 ++ litellm/proxy/proxy_server.py | 76 ++++++++++++--- litellm/proxy/utils.py | 97 +++---------------- 7 files changed, 118 insertions(+), 100 deletions(-) diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 5eb66743b5..d189b7ebe2 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -177,6 +177,32 @@ class ContextWindowExceededError(BadRequestError): # type: ignore ) # Call the base class constructor with the parameters it needs +# sub class of bad request error - meant to help us catch guardrails-related errors on proxy. +class RejectedRequestError(BadRequestError): # type: ignore + def __init__( + self, + message, + model, + llm_provider, + request_data: dict, + litellm_debug_info: Optional[str] = None, + ): + self.status_code = 400 + self.message = message + self.model = model + self.llm_provider = llm_provider + self.litellm_debug_info = litellm_debug_info + self.request_data = request_data + request = httpx.Request(method="POST", url="https://api.openai.com/v1") + response = httpx.Response(status_code=500, request=request) + super().__init__( + message=self.message, + model=self.model, # type: ignore + llm_provider=self.llm_provider, # type: ignore + response=response, + ) # Call the base class constructor with the parameters it needs + + class ContentPolicyViolationError(BadRequestError): # type: ignore # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}} def __init__( diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index accb4f80f1..e192cdaea7 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -4,7 +4,6 @@ import dotenv, os from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from litellm.utils import ModelResponse from typing import Literal, Union, Optional import traceback diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 2195a077d3..42b36950b7 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -17,4 +17,8 @@ model_list: api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault router_settings: - enable_pre_call_checks: true \ No newline at end of file + enable_pre_call_checks: true + +litellm_settings: + callbacks: ["detect_prompt_injection"] + diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b900b623be..492c222fe6 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -251,6 +251,10 @@ class LiteLLMPromptInjectionParams(LiteLLMBase): llm_api_name: Optional[str] = None llm_api_system_prompt: Optional[str] = None llm_api_fail_call_string: Optional[str] = None + reject_as_response: Optional[bool] = Field( + default=False, + description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.", + ) @root_validator(pre=True) def check_llm_api_params(cls, values): diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index 896046e943..87cae71a8a 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): try: assert call_type in [ "completion", + "text_completion", "embeddings", "image_generation", "moderation", @@ -192,6 +193,13 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): return data except HTTPException as e: + if ( + e.status_code == 400 + and isinstance(e.detail, dict) + and "error" in e.detail + ): + if self.prompt_injection_params.reject_as_response: + return e.detail["error"] raise e except Exception as e: traceback.print_exc() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6ad3e598a5..6b395e1382 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import ( get_actual_routes, ) from litellm.llms.custom_httpx.httpx_handler import HTTPHandler +from litellm.exceptions import RejectedRequestError try: from litellm._version import version @@ -3766,19 +3767,6 @@ async def chat_completion( user_api_key_dict=user_api_key_dict, data=data, call_type="completion" ) - if isinstance(data, litellm.ModelResponse): - return data - elif isinstance(data, litellm.CustomStreamWrapper): - selected_data_generator = select_data_generator( - response=data, - user_api_key_dict=user_api_key_dict, - request_data={}, - ) - - return StreamingResponse( - selected_data_generator, - media_type="text/event-stream", - ) tasks = [] tasks.append( proxy_logging_obj.during_call_hook( @@ -3893,6 +3881,40 @@ async def chat_completion( ) return response + except RejectedRequestError as e: + _data = e.request_data + _data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=_data, + ) + _chat_response = litellm.ModelResponse() + _chat_response.choices[0].message.content = e.message # type: ignore + + if data.get("stream", None) is not None and data["stream"] == True: + _iterator = litellm.utils.ModelResponseIterator( + model_response=_chat_response + ) + _streaming_response = litellm.CustomStreamWrapper( + completion_stream=_iterator, + model=data.get("model", ""), + custom_llm_provider="cached_response", + logging_obj=data.get("litellm_logging_obj", None), + ) + selected_data_generator = select_data_generator( + response=e.message, + user_api_key_dict=user_api_key_dict, + request_data=_data, + ) + + return StreamingResponse( + selected_data_generator, + media_type="text/event-stream", + ) + _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + _chat_response.usage = _usage # type: ignore + return _chat_response except Exception as e: data["litellm_status"] = "fail" # used for alerting traceback.print_exc() @@ -4112,6 +4134,34 @@ async def completion( ) return response + except RejectedRequestError as e: + _data = e.request_data + _data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=_data, + ) + if _data.get("stream", None) is not None and _data["stream"] == True: + _chat_response = litellm.ModelResponse() + _usage = litellm.Usage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + _chat_response.usage = _usage # type: ignore + _chat_response.choices[0].message.content = e.message # type: ignore + _iterator = litellm.utils.ModelResponseIterator( + model_response=_chat_response + ) + return litellm.TextCompletionStreamWrapper( + completion_stream=_iterator, + model=_data.get("model", ""), + ) + else: + _response = litellm.TextCompletionResponse() + _response.choices[0].text = e.message + return _response except Exception as e: data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index fc49ebc7fe..586b4c4cda 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) +from litellm.exceptions import RejectedRequestError from litellm._service_logger import ServiceLogging, ServiceTypes from litellm import ( ModelResponse, @@ -186,40 +187,6 @@ class ProxyLogging: ) litellm.utils.set_callbacks(callback_list=callback_list) - # fmt: off - - @overload - async def pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - data: dict, - call_type: Literal["completion"] - ) -> Union[dict, ModelResponse, CustomStreamWrapper]: - ... - - @overload - async def pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - data: dict, - call_type: Literal["text_completion"] - ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]: - ... - - @overload - async def pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - data: dict, - call_type: Literal["embeddings", - "image_generation", - "moderation", - "audio_transcription",] - ) -> dict: - ... - - # fmt: on - # The actual implementation of the function async def pre_call_hook( self, @@ -233,13 +200,7 @@ class ProxyLogging: "moderation", "audio_transcription", ], - ) -> Union[ - dict, - ModelResponse, - TextCompletionResponse, - CustomStreamWrapper, - TextCompletionStreamWrapper, - ]: + ) -> dict: """ Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. @@ -271,54 +232,20 @@ class ProxyLogging: elif isinstance(response, dict): data = response elif isinstance(response, str): - if call_type == "completion": - _chat_response = ModelResponse() - _chat_response.choices[0].message.content = response - - if ( - data.get("stream", None) is not None - and data["stream"] == True - ): - _iterator = ModelResponseIterator( - model_response=_chat_response - ) - return CustomStreamWrapper( - completion_stream=_iterator, - model=data.get("model", ""), - custom_llm_provider="cached_response", - logging_obj=data.get( - "litellm_logging_obj", None - ), - ) - return _response - elif call_type == "text_completion": - if ( - data.get("stream", None) is not None - and data["stream"] == True - ): - _chat_response = ModelResponse() - _chat_response.choices[0].message.content = response - - if ( - data.get("stream", None) is not None - and data["stream"] == True - ): - _iterator = ModelResponseIterator( - model_response=_chat_response - ) - return TextCompletionStreamWrapper( - completion_stream=_iterator, - model=data.get("model", ""), - ) - else: - _response = TextCompletionResponse() - _response.choices[0].text = response - return _response + if ( + call_type == "completion" + or call_type == "text_completion" + ): + raise RejectedRequestError( + message=response, + model=data.get("model", ""), + llm_provider="", + request_data=data, + ) else: raise HTTPException( status_code=400, detail={"error": response} ) - print_verbose(f"final data being sent to {call_type} call: {data}") return data except Exception as e: From b41f30ca6097548e2e305882f51ce07fc82677a0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 May 2024 12:32:19 -0700 Subject: [PATCH 3/4] fix(proxy_server.py): fixes for making rejected responses work with streaming --- litellm/proxy/_super_secret_config.yaml | 4 +++ .../proxy/hooks/prompt_injection_detection.py | 6 ++-- litellm/proxy/proxy_server.py | 34 +++++++++---------- litellm/utils.py | 12 +++++-- 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 42b36950b7..8db3eea3e6 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -21,4 +21,8 @@ router_settings: litellm_settings: callbacks: ["detect_prompt_injection"] + prompt_injection_params: + heuristics_check: true + similarity_check: true + reject_as_response: true diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index 87cae71a8a..08dbedd8c8 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -193,13 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): return data except HTTPException as e: + if ( e.status_code == 400 and isinstance(e.detail, dict) and "error" in e.detail + and self.prompt_injection_params is not None + and self.prompt_injection_params.reject_as_response ): - if self.prompt_injection_params.reject_as_response: - return e.detail["error"] + return e.detail["error"] raise e except Exception as e: traceback.print_exc() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6b395e1382..016db6ea32 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3894,7 +3894,7 @@ async def chat_completion( if data.get("stream", None) is not None and data["stream"] == True: _iterator = litellm.utils.ModelResponseIterator( - model_response=_chat_response + model_response=_chat_response, convert_to_delta=True ) _streaming_response = litellm.CustomStreamWrapper( completion_stream=_iterator, @@ -3903,7 +3903,7 @@ async def chat_completion( logging_obj=data.get("litellm_logging_obj", None), ) selected_data_generator = select_data_generator( - response=e.message, + response=_streaming_response, user_api_key_dict=user_api_key_dict, request_data=_data, ) @@ -4037,20 +4037,6 @@ async def completion( user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" ) - if isinstance(data, litellm.TextCompletionResponse): - return data - elif isinstance(data, litellm.TextCompletionStreamWrapper): - selected_data_generator = select_data_generator( - response=data, - user_api_key_dict=user_api_key_dict, - request_data={}, - ) - - return StreamingResponse( - selected_data_generator, - media_type="text/event-stream", - ) - ### ROUTE THE REQUESTs ### router_model_names = llm_router.model_names if llm_router is not None else [] # skip router if user passed their key @@ -4152,12 +4138,24 @@ async def completion( _chat_response.usage = _usage # type: ignore _chat_response.choices[0].message.content = e.message # type: ignore _iterator = litellm.utils.ModelResponseIterator( - model_response=_chat_response + model_response=_chat_response, convert_to_delta=True ) - return litellm.TextCompletionStreamWrapper( + _streaming_response = litellm.TextCompletionStreamWrapper( completion_stream=_iterator, model=_data.get("model", ""), ) + + selected_data_generator = select_data_generator( + response=_streaming_response, + user_api_key_dict=user_api_key_dict, + request_data=data, + ) + + return StreamingResponse( + selected_data_generator, + media_type="text/event-stream", + headers={}, + ) else: _response = litellm.TextCompletionResponse() _response.choices[0].text = e.message diff --git a/litellm/utils.py b/litellm/utils.py index 1e0485755c..5029e8c61c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6440,6 +6440,7 @@ def get_formatted_prompt( "image_generation", "audio_transcription", "moderation", + "text_completion", ], ) -> str: """ @@ -6452,6 +6453,8 @@ def get_formatted_prompt( for m in data["messages"]: if "content" in m and isinstance(m["content"], str): prompt += m["content"] + elif call_type == "text_completion": + prompt = data["prompt"] elif call_type == "embedding" or call_type == "moderation": if isinstance(data["input"], str): prompt = data["input"] @@ -12190,8 +12193,13 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str: class ModelResponseIterator: - def __init__(self, model_response): - self.model_response = model_response + def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False): + if convert_to_delta == True: + self.model_response = ModelResponse(stream=True) + _delta = self.model_response.choices[0].delta # type: ignore + _delta.content = model_response.choices[0].message.content # type: ignore + else: + self.model_response = model_response self.is_done = False # Sync iterator From bc3c06bc74bd0768fb0f8d258b1f54d4e5226626 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 May 2024 12:45:03 -0700 Subject: [PATCH 4/4] docs(call_hooks.md): update docs --- docs/my-website/docs/proxy/call_hooks.md | 147 ++++++++++++++++++++--- 1 file changed, 131 insertions(+), 16 deletions(-) diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md index 3195e2e5aa..3a8726e879 100644 --- a/docs/my-website/docs/proxy/call_hooks.md +++ b/docs/my-website/docs/proxy/call_hooks.md @@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit def __init__(self): pass - #### ASYNC #### - - async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): - pass - - async def async_log_pre_api_call(self, model, messages, kwargs): - pass - - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - pass - - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - pass - #### CALL HOOKS - proxy only #### - async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]): + async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ]) -> Optional[dict, str, Exception]: data["model"] = "my-new-model" return data + async def async_post_call_failure_hook( + self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth + ): + pass + + async def async_post_call_success_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + pass + + async def async_moderation_hook( # call made in parallel to llm api call + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal["completion", "embeddings", "image_generation"], + ): + pass + + async def async_post_call_streaming_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response: str, + ): + pass proxy_handler_instance = MyCustomHandler() ``` @@ -190,4 +209,100 @@ general_settings: **Result** - \ No newline at end of file + + +## Advanced - Return rejected message as response + +For chat completions and text completion calls, you can return a rejected message as a user response. + +Do this by returning a string. LiteLLM takes care of returning the response in the correct format depending on the endpoint and if it's streaming/non-streaming. + +For non-chat/text completion endpoints, this response is returned as a 400 status code exception. + + +### 1. Create Custom Handler + +```python +from litellm.integrations.custom_logger import CustomLogger +import litellm +from litellm.utils import get_formatted_prompt + +# This file includes the custom callbacks for LiteLLM Proxy +# Once defined, these can be passed in proxy_config.yaml +class MyCustomHandler(CustomLogger): + def __init__(self): + pass + + #### CALL HOOKS - proxy only #### + + async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ]) -> Optional[dict, str, Exception]: + formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) + + if "Hello world" in formatted_prompt: + return "This is an invalid response" + + return data + +proxy_handler_instance = MyCustomHandler() +``` + +### 2. Update config.yaml + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo + +litellm_settings: + callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] +``` + + +### 3. Test it! + +```shell +$ litellm /path/to/config.yaml +``` +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --data ' { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "Hello world" + } + ], + }' +``` + +**Expected Response** + +``` +{ + "id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "This is an invalid response.", # 👈 REJECTED RESPONSE + "role": "assistant" + } + } + ], + "created": 1716234198, + "model": null, + "object": "chat.completion", + "system_fingerprint": null, + "usage": {} +} +``` \ No newline at end of file