diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 42b36950b..8db3eea3e 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -21,4 +21,8 @@ router_settings:
 
 litellm_settings:
   callbacks: ["detect_prompt_injection"]
+  prompt_injection_params:
+    heuristics_check: true
+    similarity_check: true
+    reject_as_response: true
diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py
index 87cae71a8..08dbedd8c 100644
--- a/litellm/proxy/hooks/prompt_injection_detection.py
+++ b/litellm/proxy/hooks/prompt_injection_detection.py
@@ -193,13 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
 
             return data
         except HTTPException as e:
+
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail
+                and self.prompt_injection_params is not None
+                and self.prompt_injection_params.reject_as_response
            ):
-                if self.prompt_injection_params.reject_as_response:
-                    return e.detail["error"]
+                return e.detail["error"]
            raise e
        except Exception as e:
            traceback.print_exc()
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6b395e138..016db6ea3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3894,7 +3894,7 @@ async def chat_completion(
 
         if data.get("stream", None) is not None and data["stream"] == True:
             _iterator = litellm.utils.ModelResponseIterator(
-                model_response=_chat_response
+                model_response=_chat_response, convert_to_delta=True
             )
             _streaming_response = litellm.CustomStreamWrapper(
                 completion_stream=_iterator,
@@ -3903,7 +3903,7 @@ async def chat_completion(
                 logging_obj=data.get("litellm_logging_obj", None),
             )
             selected_data_generator = select_data_generator(
-                response=e.message,
+                response=_streaming_response,
                 user_api_key_dict=user_api_key_dict,
                 request_data=_data,
             )
@@ -4037,20 +4037,6 @@ async def completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
 
-        if isinstance(data, litellm.TextCompletionResponse):
-            return data
-        elif isinstance(data, litellm.TextCompletionStreamWrapper):
-            selected_data_generator = select_data_generator(
-                response=data,
-                user_api_key_dict=user_api_key_dict,
-                request_data={},
-            )
-
-            return StreamingResponse(
-                selected_data_generator,
-                media_type="text/event-stream",
-            )
-
         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
@@ -4152,12 +4138,24 @@ async def completion(
             _chat_response.usage = _usage  # type: ignore
             _chat_response.choices[0].message.content = e.message  # type: ignore
             _iterator = litellm.utils.ModelResponseIterator(
-                model_response=_chat_response
+                model_response=_chat_response, convert_to_delta=True
             )
-            return litellm.TextCompletionStreamWrapper(
+            _streaming_response = litellm.TextCompletionStreamWrapper(
                 completion_stream=_iterator,
                 model=_data.get("model", ""),
             )
+
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+                headers={},
+            )
         else:
             _response = litellm.TextCompletionResponse()
             _response.choices[0].text = e.message
diff --git a/litellm/utils.py b/litellm/utils.py
index 1e0485755..5029e8c61 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6440,6 +6440,7 @@ def get_formatted_prompt(
         "image_generation",
         "audio_transcription",
         "moderation",
+        "text_completion",
     ],
 ) -> str:
     """
@@ -6452,6 +6453,8 @@ def get_formatted_prompt(
         for m in data["messages"]:
             if "content" in m and isinstance(m["content"], str):
                 prompt += m["content"]
+    elif call_type == "text_completion":
+        prompt = data["prompt"]
     elif call_type == "embedding" or call_type == "moderation":
         if isinstance(data["input"], str):
             prompt = data["input"]
@@ -12190,8 +12193,13 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
 
 
 class ModelResponseIterator:
-    def __init__(self, model_response):
-        self.model_response = model_response
+    def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False):
+        if convert_to_delta == True:
+            self.model_response = ModelResponse(stream=True)
+            _delta = self.model_response.choices[0].delta  # type: ignore
+            _delta.content = model_response.choices[0].message.content  # type: ignore
+        else:
+            self.model_response = model_response
         self.is_done = False
 
     # Sync iterator