From ed5fc3d1f9db5b2f6228b0b64c05e7b98b1a8d8b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 3 Jul 2024 18:43:46 -0700 Subject: [PATCH] fix(utils.py): fix vertex anthropic streaming --- litellm/llms/vertex_ai_anthropic.py | 187 ------------------ litellm/main.py | 24 +-- .../tests/test_amazing_vertex_completion.py | 9 +- litellm/utils.py | 4 + 4 files changed, 22 insertions(+), 202 deletions(-) diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index 71dc2aacda..44a7a448eb 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -303,193 +303,6 @@ def completion( headers=vertex_headers, ) - ## Format Prompt - _is_function_call = False - _is_json_schema = False - messages = copy.deepcopy(messages) - optional_params = copy.deepcopy(optional_params) - # Separate system prompt from rest of message - system_prompt_indices = [] - system_prompt = "" - for idx, message in enumerate(messages): - if message["role"] == "system": - system_prompt += message["content"] - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for idx in reversed(system_prompt_indices): - messages.pop(idx) - if len(system_prompt) > 0: - optional_params["system"] = system_prompt - # Checks for 'response_schema' support - if passed in - if "response_format" in optional_params: - response_format_chunk = ResponseFormatChunk( - **optional_params["response_format"] # type: ignore - ) - supports_response_schema = litellm.supports_response_schema( - model=model, custom_llm_provider="vertex_ai" - ) - if ( - supports_response_schema is False - and response_format_chunk["type"] == "json_object" - and "response_schema" in response_format_chunk - ): - _is_json_schema = True - user_response_schema_message = response_schema_prompt( - model=model, - response_schema=response_format_chunk["response_schema"], - ) - messages.append( - {"role": "user", "content": user_response_schema_message} - ) - messages.append({"role": "assistant", "content": "{"}) - optional_params.pop("response_format") - # Format rest of message according to anthropic guidelines - try: - messages = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic_xml" - ) - except Exception as e: - raise VertexAIError(status_code=400, message=str(e)) - - ## Handle Tool Calling - if "tools" in optional_params: - _is_function_call = True - tool_calling_system_prompt = construct_tool_use_system_prompt( - tools=optional_params["tools"] - ) - optional_params["system"] = ( - optional_params.get("system", "\n") + tool_calling_system_prompt - ) # add the anthropic tool calling prompt to the system prompt - optional_params.pop("tools") - - stream = optional_params.pop("stream", None) - - data = { - "model": model, - "messages": messages, - **optional_params, - } - print_verbose(f"_is_function_call: {_is_function_call}") - - ## Completion Call - - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}" - ) - - if acompletion == True: - """ - - async streaming - - async completion - """ - if stream is not None and stream == True: - return async_streaming( - model=model, - messages=messages, - data=data, - print_verbose=print_verbose, - model_response=model_response, - logging_obj=logging_obj, - vertex_project=vertex_project, - vertex_location=vertex_location, - optional_params=optional_params, - client=client, - access_token=access_token, - ) - else: - return async_completion( - 
model=model, - messages=messages, - data=data, - print_verbose=print_verbose, - model_response=model_response, - logging_obj=logging_obj, - vertex_project=vertex_project, - vertex_location=vertex_location, - optional_params=optional_params, - client=client, - access_token=access_token, - ) - if stream is not None and stream == True: - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - }, - ) - response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore - return response - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - }, - ) - - vertex_ai_client: Optional[AnthropicVertex] = None - vertex_ai_client = AnthropicVertex() - if vertex_ai_client is not None: - message = vertex_ai_client.messages.create(**data) # type: ignore - - ## LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response=message, - additional_args={"complete_input_dict": data}, - ) - - text_content: str = message.content[0].text - ## TOOL CALLING - OUTPUT PARSE - if text_content is not None and contains_tag("invoke", text_content): - function_name = extract_between_tags("tool_name", text_content)[0] - function_arguments_str = extract_between_tags("invoke", text_content)[ - 0 - ].strip() - function_arguments_str = f"{function_arguments_str}" - function_arguments = parse_xml_params(function_arguments_str) - _message = litellm.Message( - tool_calls=[ - { - "id": f"call_{uuid.uuid4()}", - "type": "function", - "function": { - "name": function_name, - "arguments": json.dumps(function_arguments), - }, - } - ], - content=None, - ) - model_response.choices[0].message = _message # type: ignore - else: - if ( - _is_json_schema - ): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb - json_response = "{" + text_content[: text_content.rfind("}") + 1] - model_response.choices[0].message.content = json_response # type: ignore - else: - model_response.choices[0].message.content = text_content # type: ignore - model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) - - ## CALCULATING USAGE - prompt_tokens = message.usage.input_tokens - completion_tokens = message.usage.output_tokens - - model_response["created"] = int(time.time()) - model_response["model"] = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return model_response except Exception as e: raise VertexAIError(status_code=500, message=str(e)) diff --git a/litellm/main.py b/litellm/main.py index 72eeff2628..e3eee4c3c8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2028,18 +2028,18 @@ def completion( acompletion=acompletion, ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="vertex_ai", - logging_obj=logging, - ) - return response + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) + return response response = model_response elif custom_llm_provider == "predibase": tenant_id = ( diff --git 
a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index d8bb6d4328..c4a5ec7ca3 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -203,7 +203,7 @@ def test_vertex_ai_anthropic(): # ) def test_vertex_ai_anthropic_streaming(): try: - # load_vertex_ai_credentials() + load_vertex_ai_credentials() # litellm.set_verbose = True @@ -223,8 +223,9 @@ def test_vertex_ai_anthropic_streaming(): stream=True, ) # print("\nModel Response", response) - for chunk in response: + for idx, chunk in enumerate(response): print(f"chunk: {chunk}") + streaming_format_tests(idx=idx, chunk=chunk) # raise Exception("it worked!") except litellm.RateLimitError as e: @@ -294,8 +295,10 @@ async def test_vertex_ai_anthropic_async_streaming(): stream=True, ) + idx = 0 async for chunk in response: - print(f"chunk: {chunk}") + streaming_format_tests(idx=idx, chunk=chunk) + idx += 1 except litellm.RateLimitError as e: pass except Exception as e: diff --git a/litellm/utils.py b/litellm/utils.py index 6c1814629c..26f90fa57a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8035,6 +8035,10 @@ class CustomStreamWrapper: str_line = chunk if isinstance(chunk, bytes): # Handle binary data str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + text = "" is_finished = False finish_reason = None
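
Note (illustration only, not part of the patch): the utils.py hunk makes CustomStreamWrapper trim anything that precedes the SSE "data:" marker in a streamed chunk before the payload is parsed, which appears to be what Vertex-hosted Anthropic streaming needed here. A minimal standalone sketch of that trimming step follows; parse_sse_line is a hypothetical helper name used for illustration, not a litellm API.

    # Sketch of the chunk-trimming behaviour added to CustomStreamWrapper in utils.py.
    # parse_sse_line is a hypothetical helper for illustration, not a litellm function.
    import json
    from typing import Optional, Union

    def parse_sse_line(chunk: Union[bytes, str]) -> Optional[dict]:
        # Convert raw bytes to a string, mirroring the existing wrapper logic.
        str_line = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
        # Drop anything before the SSE "data:" marker so json.loads sees a clean payload.
        index = str_line.find("data:")
        if index != -1:
            str_line = str_line[index:]
        if not str_line.startswith("data:"):
            return None
        payload = str_line[len("data:"):].strip()
        if payload in ("", "[DONE]"):
            return None
        return json.loads(payload)

    # A chunk with a leading event line still parses once the prefix is trimmed.
    print(parse_sse_line(b'event: delta\ndata: {"type": "content_block_delta", "text": "Hi"}'))

Without the find-and-slice step, a chunk whose "data:" marker is not at the start of the line would fail the downstream parsing, which matches the streaming-format assertions the patch adds to test_amazing_vertex_completion.py.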