fix(vertex_ai_anthropic.py): support streaming, async completion, async streaming for vertex ai anthropic

Krrish Dholakia 2024-04-05 09:27:48 -07:00
parent eb34306099
commit f0c4ff6e60
8 changed files with 373 additions and 14 deletions

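By way of illustration, this is roughly what the commit enables from the caller's side. This is a sketch, not code from the repo: the model id is illustrative and Vertex AI credentials/project configuration are assumed to already be in place. The entry points (litellm.completion / litellm.acompletion with stream=True) are litellm's standard API.

import asyncio
import litellm

# sync streaming (any Vertex AI Claude 3 model id applies)
response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    print(chunk)

# async completion + async streaming
async def main():
    response = await litellm.acompletion(
        model="vertex_ai/claude-3-sonnet@20240229",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())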

@@ -4849,6 +4849,17 @@ def get_optional_params(
         print_verbose(
             f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
         )
+    elif (
+        custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
+    ):
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAIAnthropicConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
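The new branch routes OpenAI-style kwargs through litellm.VertexAIAnthropicConfig().map_openai_params instead of the generic Vertex AI path. As a rough sketch of what such a mapping does (this is not the actual VertexAIAnthropicConfig implementation; the parameter table below is an assumption for illustration only):

def map_openai_params(non_default_params: dict, optional_params: dict) -> dict:
    # hypothetical OpenAI -> Anthropic parameter translation
    for param, value in non_default_params.items():
        if param in ("max_tokens", "temperature", "top_p", "stream"):
            optional_params[param] = value  # names line up 1:1
        elif param == "stop":
            optional_params["stop_sequences"] = value  # Anthropic's name differs
    return optional_params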
@@ -5229,7 +5240,9 @@ def get_optional_params(
             extra_body  # openai client supports `extra_body` param
         )
     else:  # assume passing in params for openai/azure openai
-        print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
+        print_verbose(
+            f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
+        )
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider="openai"
         )
@@ -8710,6 +8723,58 @@ class CustomStreamWrapper:
                 "finish_reason": finish_reason,
             }
 
+    def handle_vertexai_anthropic_chunk(self, chunk):
+        """
+        - MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai
+        - ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start'); custom_llm_provider: vertex_ai
+        - ContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta'); custom_llm_provider: vertex_ai
+        """
+        text = ""
+        prompt_tokens = None
+        completion_tokens = None
+        is_finished = False
+        finish_reason = None
+        type_chunk = getattr(chunk, "type", None)
+        if type_chunk == "message_start":
+            message = getattr(chunk, "message", None)
+            text = ""  # lets us return a chunk with usage to user
+            _usage = getattr(message, "usage", None)
+            if _usage is not None:
+                prompt_tokens = getattr(_usage, "input_tokens", None)
+                completion_tokens = getattr(_usage, "output_tokens", None)
+        elif type_chunk == "content_block_delta":
+            """
+            Anthropic content chunk
+            chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
+            """
+            delta = getattr(chunk, "delta", None)
+            if delta is not None:
+                text = getattr(delta, "text", "")
+            else:
+                text = ""
+        elif type_chunk == "message_delta":
+            """
+            Anthropic
+            chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
+            """
+            # TODO - get usage from this chunk, set in response
+            delta = getattr(chunk, "delta", None)
+            if delta is not None:
+                finish_reason = getattr(delta, "stop_reason", "stop")
+                is_finished = True
+            _usage = getattr(chunk, "usage", None)
+            if _usage is not None:
+                prompt_tokens = getattr(_usage, "input_tokens", None)
+                completion_tokens = getattr(_usage, "output_tokens", None)
+        return {
+            "text": text,
+            "is_finished": is_finished,
+            "finish_reason": finish_reason,
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+        }
+
     def handle_together_ai_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         text = ""
@@ -9377,7 +9442,33 @@ class CustomStreamWrapper:
                 else:
                     completion_obj["content"] = str(chunk)
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
-                if hasattr(chunk, "candidates") == True:
+                if self.model.startswith("claude-3"):
+                    response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk)
+                    if response_obj is None:
+                        return
+                    completion_obj["content"] = response_obj["text"]
+                    if response_obj.get("prompt_tokens", None) is not None:
+                        model_response.usage.prompt_tokens = response_obj[
+                            "prompt_tokens"
+                        ]
+                    if response_obj.get("completion_tokens", None) is not None:
+                        model_response.usage.completion_tokens = response_obj[
+                            "completion_tokens"
+                        ]
+                    if hasattr(model_response.usage, "prompt_tokens"):
+                        model_response.usage.total_tokens = (
+                            getattr(model_response.usage, "total_tokens", 0)
+                            + model_response.usage.prompt_tokens
+                        )
+                    if hasattr(model_response.usage, "completion_tokens"):
+                        model_response.usage.total_tokens = (
+                            getattr(model_response.usage, "total_tokens", 0)
+                            + model_response.usage.completion_tokens
+                        )
+                    if response_obj["is_finished"]:
+                        self.received_finish_reason = response_obj["finish_reason"]
+                elif hasattr(chunk, "candidates") == True:
                     try:
                         try:
                             completion_obj["content"] = chunk.text
@@ -9629,6 +9720,18 @@ class CustomStreamWrapper:
             print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
             ## RETURN ARG
             if (
+                "content" in completion_obj
+                and isinstance(completion_obj["content"], str)
+                and len(completion_obj["content"]) == 0
+                and hasattr(model_response.usage, "prompt_tokens")
+            ):
+                if self.sent_first_chunk == False:
+                    completion_obj["role"] = "assistant"
+                    self.sent_first_chunk = True
+                model_response.choices[0].delta = Delta(**completion_obj)
+                print_verbose(f"returning model_response: {model_response}")
+                return model_response
+            elif (
                 "content" in completion_obj
                 and isinstance(completion_obj["content"], str)
                 and len(completion_obj["content"]) > 0
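The final hunk means a stream can now yield a chunk whose delta content is empty but whose usage is populated (the message_start case above). A consumer-side sketch of handling that; the model id is illustrative, and the attribute access is hedged with getattr since chunk shape details vary by provider and litellm version.

import litellm

response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",  # illustrative model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    usage = getattr(chunk, "usage", None)
    if usage is not None:
        print(f"\n[usage chunk] {usage}")
    content = chunk.choices[0].delta.content
    if content:
        print(content, end="")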