fix(utils.py): fix vertex anthropic streaming

This commit is contained in:
Krrish Dholakia 2024-07-03 18:43:46 -07:00
parent 7007ace6c2
commit ed5fc3d1f9
4 changed files with 22 additions and 202 deletions

View file

@ -303,193 +303,6 @@ def completion(
headers=vertex_headers,
)
## Format Prompt
_is_function_call = False
_is_json_schema = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Checks for 'response_schema' support - if passed in
if "response_format" in optional_params:
response_format_chunk = ResponseFormatChunk(
**optional_params["response_format"] # type: ignore
)
supports_response_schema = litellm.supports_response_schema(
model=model, custom_llm_provider="vertex_ai"
)
if (
supports_response_schema is False
and response_format_chunk["type"] == "json_object"
and "response_schema" in response_format_chunk
):
_is_json_schema = True
user_response_schema_message = response_schema_prompt(
model=model,
response_schema=response_format_chunk["response_schema"],
)
messages.append(
{"role": "user", "content": user_response_schema_message}
)
messages.append({"role": "assistant", "content": "{"})
optional_params.pop("response_format")
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
except Exception as e:
raise VertexAIError(status_code=400, message=str(e))
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
print_verbose(f"_is_function_call: {_is_function_call}")
## Completion Call
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
)
if acompletion == True:
"""
- async streaming
- async completion
"""
if stream is not None and stream == True:
return async_streaming(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
else:
return async_completion(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
if stream is not None and stream == True:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore
return response
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
vertex_ai_client: Optional[AnthropicVertex] = None
vertex_ai_client = AnthropicVertex()
if vertex_ai_client is not None:
message = vertex_ai_client.messages.create(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=message,
additional_args={"complete_input_dict": data},
)
text_content: str = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
if (
_is_json_schema
): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
json_response = "{" + text_content[: text_content.rfind("}") + 1]
model_response.choices[0].message.content = json_response # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))