fix(utils.py): return finish reason for last vertex ai chunk

Krrish Dholakia 2024-02-06 09:21:03 -08:00
parent f2ef32bcee
commit 3afa5230d6
2 changed files with 52 additions and 13 deletions

View file

@@ -1746,7 +1746,33 @@ async def async_data_generator(response, user_api_key_dict):
         done_message = "[DONE]"
         yield f"data: {done_message}\n\n"
     except Exception as e:
-        yield f"data: {str(e)}\n\n"
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
 def select_data_generator(response, user_api_key_dict):
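Note on this hunk: rather than streaming the error string back to the client, the generator now reports the failure through the proxy's post-call failure hook and re-raises it as a ProxyException, pulling message/type/param/code off the original exception with getattr fallbacks. A minimal, self-contained sketch of that fallback pattern follows; ToyProxyException and to_proxy_exception are illustrative stand-ins, not litellm code.

    import traceback

    class ToyProxyException(Exception):
        # Stand-in for the proxy's error type; fields mirror the call above.
        def __init__(self, message, type, param, code):
            super().__init__(message)
            self.message, self.type, self.param, self.code = message, type, param, code

    def to_proxy_exception(e: Exception) -> ToyProxyException:
        # Fall back to a traceback-augmented message when the original
        # exception exposes no structured fields.
        error_msg = f"{str(e)}\n\n{traceback.format_exc()}"
        return ToyProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )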
@@ -1754,7 +1780,7 @@ def select_data_generator(response, user_api_key_dict):
         # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
         if hasattr(
             response, "custom_llm_provider"
-        ) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
+        ) and response.custom_llm_provider in ["sagemaker"]:
             return data_generator(
                 response=response,
             )
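Note on this hunk: together_ai is removed from the sync list, so only SageMaker (whose boto3 client cannot stream asynchronously) is routed to the synchronous data_generator; other providers fall through to the async path. A rough sketch of the routing decision, with hypothetical names:

    # Hypothetical helper, not from this commit.
    SYNC_ONLY_PROVIDERS = {"sagemaker"}

    def needs_sync_generator(response) -> bool:
        provider = getattr(response, "custom_llm_provider", None)
        return provider in SYNC_ONLY_PROVIDERS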
@@ -2239,7 +2265,6 @@ async def chat_completion(
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
             )
-
             return StreamingResponse(
                 selected_data_generator,
                 media_type="text/event-stream",

View file

@@ -169,6 +169,8 @@ def map_finish_reason(
         return "stop"
     elif finish_reason == "SAFETY":  # vertex ai
         return "content_filter"
+    elif finish_reason == "STOP":  # vertex ai
+        return "stop"
     return finish_reason
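Note on this hunk: Vertex AI reports finish reasons as upper-case enum names, so "STOP" is now normalized to the OpenAI-style "stop" just as "SAFETY" already maps to "content_filter". A standalone sketch of that mapping, illustrative rather than the full helper:

    def map_vertex_finish_reason(finish_reason: str) -> str:
        # Unknown values pass through unchanged, as in the function above.
        vertex_map = {"STOP": "stop", "SAFETY": "content_filter"}
        return vertex_map.get(finish_reason, finish_reason)

    assert map_vertex_finish_reason("STOP") == "stop"
    assert map_vertex_finish_reason("SAFETY") == "content_filter"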
@@ -1305,7 +1307,7 @@ class Logging:
                     )
                     if callback == "langfuse":
                         global langFuseLogger
-                        verbose_logger.debug("reaches langfuse for logging!")
+                        verbose_logger.debug("reaches langfuse for success logging!")
                         kwargs = {}
                         for k, v in self.model_call_details.items():
                             if (
@@ -6706,7 +6708,13 @@ def exception_type(
                         message=f"VertexAIException - {error_str}",
                         model=model,
                         llm_provider="vertex_ai",
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=429,
+                            request=httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            ),
+                        ),
                     )
                 elif (
                     "429 Quota exceeded" in error_str
@@ -8341,13 +8349,20 @@ class CustomStreamWrapper:
                 completion_obj["content"] = chunk.text
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
-                    # print(chunk)
-                    if hasattr(chunk, "text"):
-                        # vertexAI chunks return
-                        # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python
-                        # This Python code says "Hi" 100 times.
-                        # Create])
-                        completion_obj["content"] = chunk.text
+                    if hasattr(chunk, "candidates") == True:
+                        try:
+                            completion_obj["content"] = chunk.text
+                            if hasattr(chunk.candidates[0], "finish_reason"):
+                                model_response.choices[
+                                    0
+                                ].finish_reason = map_finish_reason(
+                                    chunk.candidates[0].finish_reason.name
+                                )
+                        except:
+                            if chunk.candidates[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
                     else:
                         completion_obj["content"] = str(chunk)
                 except StopIteration as e:
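Note on this hunk (the change the commit title refers to): the final Vertex AI streaming chunk carries a finish_reason on its first candidate, which is now mapped onto model_response.choices[0]; if reading chunk.text fails and the candidate was blocked, the wrapper raises instead of dropping the chunk silently. A self-contained sketch of that branch using SimpleNamespace stand-ins rather than the Vertex AI SDK types:

    from types import SimpleNamespace

    VERTEX_FINISH_MAP = {"STOP": "stop", "SAFETY": "content_filter"}

    def handle_vertex_chunk(chunk, choice):
        # `choice` stands in for model_response.choices[0].
        try:
            choice.content = chunk.text  # the real SDK raises here on blocked output
            if hasattr(chunk.candidates[0], "finish_reason"):
                name = chunk.candidates[0].finish_reason.name
                choice.finish_reason = VERTEX_FINISH_MAP.get(name, name)
        except Exception:
            if chunk.candidates[0].finish_reason.name == "SAFETY":
                raise Exception(f"The response was blocked by VertexAI. {chunk}")

    last_chunk = SimpleNamespace(
        text="",  # the last chunk typically adds no new text
        candidates=[SimpleNamespace(finish_reason=SimpleNamespace(name="STOP"))],
    )
    choice = SimpleNamespace(content=None, finish_reason=None)
    handle_vertex_chunk(last_chunk, choice)
    assert choice.finish_reason == "stop"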
@@ -8636,7 +8651,6 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
             ):
-                print_verbose(f"INSIDE ASYNC STREAMING!!!")
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"
                 )