mirror of https://github.com/BerriAI/litellm.git
fix(utils.py): fix vertex anthropic streaming
This commit is contained in:
parent 7007ace6c2
commit ed5fc3d1f9
4 changed files with 22 additions and 202 deletions
@@ -303,193 +303,6 @@ def completion(
             headers=vertex_headers,
         )
-
-        ## Format Prompt
-        _is_function_call = False
-        _is_json_schema = False
-        messages = copy.deepcopy(messages)
-        optional_params = copy.deepcopy(optional_params)
-        # Separate system prompt from rest of message
-        system_prompt_indices = []
-        system_prompt = ""
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                system_prompt += message["content"]
-                system_prompt_indices.append(idx)
-        if len(system_prompt_indices) > 0:
-            for idx in reversed(system_prompt_indices):
-                messages.pop(idx)
-        if len(system_prompt) > 0:
-            optional_params["system"] = system_prompt
-        # Checks for 'response_schema' support - if passed in
-        if "response_format" in optional_params:
-            response_format_chunk = ResponseFormatChunk(
-                **optional_params["response_format"]  # type: ignore
-            )
-            supports_response_schema = litellm.supports_response_schema(
-                model=model, custom_llm_provider="vertex_ai"
-            )
-            if (
-                supports_response_schema is False
-                and response_format_chunk["type"] == "json_object"
-                and "response_schema" in response_format_chunk
-            ):
-                _is_json_schema = True
-                user_response_schema_message = response_schema_prompt(
-                    model=model,
-                    response_schema=response_format_chunk["response_schema"],
-                )
-                messages.append(
-                    {"role": "user", "content": user_response_schema_message}
-                )
-                messages.append({"role": "assistant", "content": "{"})
-                optional_params.pop("response_format")
-        # Format rest of message according to anthropic guidelines
-        try:
-            messages = prompt_factory(
-                model=model, messages=messages, custom_llm_provider="anthropic_xml"
-            )
-        except Exception as e:
-            raise VertexAIError(status_code=400, message=str(e))
-
-        ## Handle Tool Calling
-        if "tools" in optional_params:
-            _is_function_call = True
-            tool_calling_system_prompt = construct_tool_use_system_prompt(
-                tools=optional_params["tools"]
-            )
-            optional_params["system"] = (
-                optional_params.get("system", "\n") + tool_calling_system_prompt
-            )  # add the anthropic tool calling prompt to the system prompt
-            optional_params.pop("tools")
-
-        stream = optional_params.pop("stream", None)
-
-        data = {
-            "model": model,
-            "messages": messages,
-            **optional_params,
-        }
-        print_verbose(f"_is_function_call: {_is_function_call}")
-
-        ## Completion Call
-
-        print_verbose(
-            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
-        )
-
-        if acompletion == True:
-            """
-            - async streaming
-            - async completion
-            """
-            if stream is not None and stream == True:
-                return async_streaming(
-                    model=model,
-                    messages=messages,
-                    data=data,
-                    print_verbose=print_verbose,
-                    model_response=model_response,
-                    logging_obj=logging_obj,
-                    vertex_project=vertex_project,
-                    vertex_location=vertex_location,
-                    optional_params=optional_params,
-                    client=client,
-                    access_token=access_token,
-                )
-            else:
-                return async_completion(
-                    model=model,
-                    messages=messages,
-                    data=data,
-                    print_verbose=print_verbose,
-                    model_response=model_response,
-                    logging_obj=logging_obj,
-                    vertex_project=vertex_project,
-                    vertex_location=vertex_location,
-                    optional_params=optional_params,
-                    client=client,
-                    access_token=access_token,
-                )
-        if stream is not None and stream == True:
-            ## LOGGING
-            logging_obj.pre_call(
-                input=messages,
-                api_key=None,
-                additional_args={
-                    "complete_input_dict": optional_params,
-                },
-            )
-            response = vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
-            return response
-
-        ## LOGGING
-        logging_obj.pre_call(
-            input=messages,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-            },
-        )
-
-        vertex_ai_client: Optional[AnthropicVertex] = None
-        vertex_ai_client = AnthropicVertex()
-        if vertex_ai_client is not None:
-            message = vertex_ai_client.messages.create(**data)  # type: ignore
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key="",
-            original_response=message,
-            additional_args={"complete_input_dict": data},
-        )
-
-        text_content: str = message.content[0].text
-        ## TOOL CALLING - OUTPUT PARSE
-        if text_content is not None and contains_tag("invoke", text_content):
-            function_name = extract_between_tags("tool_name", text_content)[0]
-            function_arguments_str = extract_between_tags("invoke", text_content)[
-                0
-            ].strip()
-            function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
-            function_arguments = parse_xml_params(function_arguments_str)
-            _message = litellm.Message(
-                tool_calls=[
-                    {
-                        "id": f"call_{uuid.uuid4()}",
-                        "type": "function",
-                        "function": {
-                            "name": function_name,
-                            "arguments": json.dumps(function_arguments),
-                        },
-                    }
-                ],
-                content=None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-        else:
-            if (
-                _is_json_schema
-            ):  # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
-                json_response = "{" + text_content[: text_content.rfind("}") + 1]
-                model_response.choices[0].message.content = json_response  # type: ignore
-            else:
-                model_response.choices[0].message.content = text_content  # type: ignore
-        model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
-
-        ## CALCULATING USAGE
-        prompt_tokens = message.usage.input_tokens
-        completion_tokens = message.usage.output_tokens
-
-        model_response["created"] = int(time.time())
-        model_response["model"] = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-        setattr(model_response, "usage", usage)
-        return model_response
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
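The block removed above hand-rolled the Anthropic request format for Claude models on Vertex AI: it separated system messages out of the message list, appended the XML tool-calling and JSON-schema prompts, then parsed tool calls and usage back out of the raw response. For reference, a minimal standalone sketch of the system-prompt separation step that code performed (split_system_prompt is an illustrative name, not a litellm function):

import copy

def split_system_prompt(messages: list) -> tuple[str, list]:
    # Mirrors the removed logic: concatenate all "system" messages into one
    # string and drop them from the list, since Anthropic-style APIs take the
    # system prompt as a separate top-level parameter rather than as a message.
    messages = copy.deepcopy(messages)
    system_prompt = ""
    system_prompt_indices = []
    for idx, message in enumerate(messages):
        if message["role"] == "system":
            system_prompt += message["content"]
            system_prompt_indices.append(idx)
    for idx in reversed(system_prompt_indices):
        messages.pop(idx)
    return system_prompt, messages

system, rest = split_system_prompt(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ]
)
assert system == "You are a helpful assistant."
assert rest == [{"role": "user", "content": "Hi"}]

With this bespoke path deleted, only the surrounding try/except and the call that passes headers=vertex_headers remain in the function.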
@@ -2028,18 +2028,18 @@ def completion(
                 acompletion=acompletion,
             )
 
             if (
                 "stream" in optional_params
                 and optional_params["stream"] == True
                 and acompletion == False
             ):
                 response = CustomStreamWrapper(
                     model_response,
                     model,
                     custom_llm_provider="vertex_ai",
                     logging_obj=logging,
                 )
                 return response
             response = model_response
         elif custom_llm_provider == "predibase":
             tenant_id = (
@@ -203,7 +203,7 @@ def test_vertex_ai_anthropic():
 # )
 def test_vertex_ai_anthropic_streaming():
     try:
-        # load_vertex_ai_credentials()
+        load_vertex_ai_credentials()
 
         # litellm.set_verbose = True
 
@@ -223,8 +223,9 @@ def test_vertex_ai_anthropic_streaming():
             stream=True,
         )
         # print("\nModel Response", response)
-        for chunk in response:
+        for idx, chunk in enumerate(response):
             print(f"chunk: {chunk}")
+            streaming_format_tests(idx=idx, chunk=chunk)
 
         # raise Exception("it worked!")
     except litellm.RateLimitError as e:
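The updated sync test drives the same path an application would: a streamed completion call against a Claude model on Vertex AI, consuming chunks one by one. A hedged end-to-end sketch of that usage (the model id and prompt are illustrative, and Vertex credentials/project are assumed to be configured already):

import litellm

# Assumed model id for Claude on Vertex AI; substitute whatever the project exposes.
response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "Say hi in one word."}],
    stream=True,
)
for idx, chunk in enumerate(response):
    # Each chunk is an OpenAI-style streaming delta produced by CustomStreamWrapper.
    print(idx, chunk.choices[0].delta)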
@@ -294,8 +295,10 @@ async def test_vertex_ai_anthropic_async_streaming():
             stream=True,
         )
 
+        idx = 0
         async for chunk in response:
-            print(f"chunk: {chunk}")
+            streaming_format_tests(idx=idx, chunk=chunk)
+            idx += 1
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
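Both test hunks now route every chunk through streaming_format_tests with a running index. That helper lives in litellm's test suite and is not shown in this diff; purely as an illustration, here is a stand-in for the kind of shape check such a helper might make on each chunk (every asserted field is an assumption, not the helper's actual contents):

def check_stream_chunk(idx: int, chunk) -> None:
    # Illustrative only: verify the chunk looks like an OpenAI-style streaming delta.
    assert chunk.id is not None
    assert chunk.object == "chat.completion.chunk"
    assert isinstance(chunk.created, int)
    assert len(chunk.choices) > 0
    if idx == 0:
        # the first chunk typically carries the assistant role
        assert chunk.choices[0].delta.role == "assistant"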
@@ -8035,6 +8035,10 @@ class CustomStreamWrapper:
         str_line = chunk
         if isinstance(chunk, bytes):  # Handle binary data
             str_line = chunk.decode("utf-8")  # Convert bytes to string
+            index = str_line.find("data:")
+            if index != -1:
+                str_line = str_line[index:]
+
         text = ""
         is_finished = False
         finish_reason = None
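The utils.py hunk is the streaming fix named in the commit message: a raw chunk off the wire can carry other SSE fields before the JSON payload, and CustomStreamWrapper now trims everything up to the "data:" marker after decoding bytes. A minimal sketch of that behaviour in isolation (strip_sse_prefix is an illustrative name; in litellm the logic sits inline as shown above):

def strip_sse_prefix(chunk) -> str:
    # Mirrors the added lines: decode raw bytes and drop anything before the
    # SSE "data:" marker so the JSON payload that follows can be parsed.
    str_line = chunk
    if isinstance(chunk, bytes):
        str_line = chunk.decode("utf-8")
        index = str_line.find("data:")
        if index != -1:
            str_line = str_line[index:]
    return str_line

raw = b'event: content_block_delta\r\ndata: {"type": "content_block_delta"}'
print(strip_sse_prefix(raw))  # -> 'data: {"type": "content_block_delta"}'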