forked from phoenix/litellm-mirror
fix(proxy_server.py): fixes for making rejected responses work with streaming
This commit is contained in:
parent
f11f207ae6
commit
b41f30ca60
4 changed files with 34 additions and 22 deletions
|
@ -21,4 +21,8 @@ router_settings:
|
|||
|
||||
litellm_settings:
|
||||
callbacks: ["detect_prompt_injection"]
|
||||
prompt_injection_params:
|
||||
heuristics_check: true
|
||||
similarity_check: true
|
||||
reject_as_response: true
|
||||
|
||||
|
|
|
@ -193,13 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
return data
|
||||
|
||||
except HTTPException as e:
|
||||
|
||||
if (
|
||||
e.status_code == 400
|
||||
and isinstance(e.detail, dict)
|
||||
and "error" in e.detail
|
||||
and self.prompt_injection_params is not None
|
||||
and self.prompt_injection_params.reject_as_response
|
||||
):
|
||||
if self.prompt_injection_params.reject_as_response:
|
||||
return e.detail["error"]
|
||||
return e.detail["error"]
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -3894,7 +3894,7 @@ async def chat_completion(
|
|||
|
||||
if data.get("stream", None) is not None and data["stream"] == True:
|
||||
_iterator = litellm.utils.ModelResponseIterator(
|
||||
model_response=_chat_response
|
||||
model_response=_chat_response, convert_to_delta=True
|
||||
)
|
||||
_streaming_response = litellm.CustomStreamWrapper(
|
||||
completion_stream=_iterator,
|
||||
|
@ -3903,7 +3903,7 @@ async def chat_completion(
|
|||
logging_obj=data.get("litellm_logging_obj", None),
|
||||
)
|
||||
selected_data_generator = select_data_generator(
|
||||
response=e.message,
|
||||
response=_streaming_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=_data,
|
||||
)
|
||||
|
@ -4037,20 +4037,6 @@ async def completion(
|
|||
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
|
||||
)
|
||||
|
||||
if isinstance(data, litellm.TextCompletionResponse):
|
||||
return data
|
||||
elif isinstance(data, litellm.TextCompletionStreamWrapper):
|
||||
selected_data_generator = select_data_generator(
|
||||
response=data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data={},
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
# skip router if user passed their key
|
||||
|
@ -4152,12 +4138,24 @@ async def completion(
|
|||
_chat_response.usage = _usage # type: ignore
|
||||
_chat_response.choices[0].message.content = e.message # type: ignore
|
||||
_iterator = litellm.utils.ModelResponseIterator(
|
||||
model_response=_chat_response
|
||||
model_response=_chat_response, convert_to_delta=True
|
||||
)
|
||||
return litellm.TextCompletionStreamWrapper(
|
||||
_streaming_response = litellm.TextCompletionStreamWrapper(
|
||||
completion_stream=_iterator,
|
||||
model=_data.get("model", ""),
|
||||
)
|
||||
|
||||
selected_data_generator = select_data_generator(
|
||||
response=_streaming_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers={},
|
||||
)
|
||||
else:
|
||||
_response = litellm.TextCompletionResponse()
|
||||
_response.choices[0].text = e.message
|
||||
|
|
|
@ -6440,6 +6440,7 @@ def get_formatted_prompt(
|
|||
"image_generation",
|
||||
"audio_transcription",
|
||||
"moderation",
|
||||
"text_completion",
|
||||
],
|
||||
) -> str:
|
||||
"""
|
||||
|
@ -6452,6 +6453,8 @@ def get_formatted_prompt(
|
|||
for m in data["messages"]:
|
||||
if "content" in m and isinstance(m["content"], str):
|
||||
prompt += m["content"]
|
||||
elif call_type == "text_completion":
|
||||
prompt = data["prompt"]
|
||||
elif call_type == "embedding" or call_type == "moderation":
|
||||
if isinstance(data["input"], str):
|
||||
prompt = data["input"]
|
||||
|
@ -12190,8 +12193,13 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
|
|||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
self.model_response = model_response
|
||||
def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False):
|
||||
if convert_to_delta == True:
|
||||
self.model_response = ModelResponse(stream=True)
|
||||
_delta = self.model_response.choices[0].delta # type: ignore
|
||||
_delta.content = model_response.choices[0].message.content # type: ignore
|
||||
else:
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue