From bd37a9cb5e3e34aa4550a41c43ab9c28db825f6f Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 11:12:16 -0800 Subject: [PATCH] (fix) proxy - streaming sagemaker --- litellm/proxy/proxy_server.py | 26 ++++++++++++++++++-------- litellm/proxy/tests/test_openai_js.js | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 78e756a2a..f4eb04874 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1658,11 +1658,16 @@ async def completion( "stream" in data and data["stream"] == True ): # use generate_responses to stream responses custom_headers = {"x-litellm-model-id": model_id} - return StreamingResponse( - async_data_generator( - user_api_key_dict=user_api_key_dict, + stream_content = async_data_generator( + user_api_key_dict=user_api_key_dict, + response=response, + ) + if response.custom_llm_provider == "sagemaker": + stream_content = data_generator( response=response, - ), + ) + return StreamingResponse( + stream_content, media_type="text/event-stream", headers=custom_headers, ) @@ -1820,11 +1825,16 @@ async def chat_completion( "stream" in data and data["stream"] == True ): # use generate_responses to stream responses custom_headers = {"x-litellm-model-id": model_id} - return StreamingResponse( - async_data_generator( - user_api_key_dict=user_api_key_dict, + stream_content = async_data_generator( + user_api_key_dict=user_api_key_dict, + response=response, + ) + if response.custom_llm_provider == "sagemaker": + stream_content = data_generator( response=response, - ), + ) + return StreamingResponse( + stream_content, media_type="text/event-stream", headers=custom_headers, ) diff --git a/litellm/proxy/tests/test_openai_js.js b/litellm/proxy/tests/test_openai_js.js index 7e74eeca3..c0f25cf05 100644 --- a/litellm/proxy/tests/test_openai_js.js +++ b/litellm/proxy/tests/test_openai_js.js @@ -4,7 +4,7 @@ const openai = require('openai'); process.env.DEBUG=false; async function runOpenAI() { const client = new openai.OpenAI({ - apiKey: 'sk-yPX56TDqBpr23W7ruFG3Yg', + apiKey: 'sk-JkKeNi6WpWDngBsghJ6B9g', baseURL: 'http://0.0.0.0:8000' });