fix(vertex_ai.py): add support for real async streaming + completion calls

Krrish Dholakia 2023-12-13 11:53:55 -08:00
parent 07015843ac
commit 69c29f8f86
5 changed files with 134 additions and 49 deletions
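For orientation, here is a minimal sketch of the call pattern this commit enables, mirroring the new async tests added below. The model name, sampling parameters, and credential setup are illustrative, and it assumes Vertex AI auth (project, location, or application default credentials) is already configured:

import asyncio
from litellm import acompletion

async def main() -> None:
    # Assumes Vertex AI credentials/project are already configured, as the new
    # tests below do via load_vertex_ai_credentials().
    messages = [{"content": "Hello, how are you?", "role": "user"}]

    # Real async (non-streaming) completion call
    response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7)
    print(response.choices[0].message.content)

    # Real async streaming call: iterate the returned async generator of chunks
    stream = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, stream=True)
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())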

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm
 import httpx
@@ -108,37 +108,38 @@ def completion(
             mode = "chat"
             request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
         elif model in litellm.vertex_text_models:
-            text_model = TextGenerationModel.from_pretrained(model)
+            llm_model = TextGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
+            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_text_models:
-            text_model = CodeGenerationModel.from_pretrained(model)
+            llm_model = CodeGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
+            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
         else: # vertex_code_llm_models
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
-        if acompletion == True and model in litellm.vertex_language_models: # [TODO] expand support to vertex ai chat + text models
+        if acompletion == True: # [TODO] expand support to vertex ai chat + text models
             if optional_params.get("stream", False) is True:
                 # async streaming
-                pass
+                return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, **optional_params)
-            return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, **optional_params)
+            return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, **optional_params)
         if mode == "":
             chat = llm_model.start_chat()
             request_str+= f"chat = llm_model.start_chat()\n"
             if "stream" in optional_params and optional_params["stream"] == True:
-                request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
+                stream = optional_params.pop("stream")
+                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
+                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
                 optional_params["stream"] = True
                 return model_response
-            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
@@ -165,20 +166,19 @@ def completion(
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":
             if "stream" in optional_params and optional_params["stream"] == True:
                 optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
-                request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
+                request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = text_model.predict_streaming(prompt, **optional_params)
+                model_response = llm_model.predict_streaming(prompt, **optional_params)
                 optional_params["stream"] = True
                 return model_response
-            request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-            completion_response = text_model.predict(prompt, **optional_params).text
+            completion_response = llm_model.predict(prompt, **optional_params).text
         ## LOGGING
         logging_obj.post_call(
@@ -216,7 +216,7 @@ def completion(
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
-async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
+async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, **optional_params):
     """
     Add support for acompletion calls for gemini-pro
     """
@@ -224,19 +224,31 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
     if mode == "":
         # gemini-pro
-        llm_model = llm_model.start_chat()
+        chat = llm_model.start_chat()
         ## LOGGING
         logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await llm_model.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
+        response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
         completion_response = response_obj.text
         response_obj = response_obj._raw_response
     elif mode == "chat":
         # chat-bison etc.
-        pass
+        chat = llm_model.start_chat()
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response_obj = await chat.send_message_async(prompt, **optional_params)
+        completion_response = response_obj.text
     elif mode == "text":
         # gecko etc.
-        pass
+        request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response_obj = await llm_model.predict_async(prompt, **optional_params)
+        completion_response = response_obj.text
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt, api_key=None, original_response=completion_response
+    )
     ## RESPONSE OBJECT
     if len(str(completion_response)) > 0:
@@ -252,13 +264,53 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
         usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                       completion_tokens=response_obj.usage_metadata.candidates_token_count,
                       total_tokens=response_obj.usage_metadata.total_token_count)
+    else:
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
     model_response.usage = usage
     return model_response
-def async_streaming():
+async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
     """
     Add support for async streaming calls for gemini-pro
     """
+    from vertexai.preview.generative_models import GenerationConfig
+    if mode == "":
+        # gemini-pro
+        chat = llm_model.start_chat()
+        stream = optional_params.pop("stream")
+        request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
+        optional_params["stream"] = True
+    elif mode == "chat":
+        chat = llm_model.start_chat()
+        optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
+        request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = chat.send_message_streaming_async(prompt, **optional_params)
+        optional_params["stream"] = True
+    elif mode == "text":
+        optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
+        request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = llm_model.predict_streaming_async(prompt, **optional_params)
+    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
+    async for transformed_chunk in streamwrapper:
+        yield transformed_chunk
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
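For readers skimming the diff above: async_streaming() is an async generator that obtains the provider-native stream, wraps it in CustomStreamWrapper, and re-yields each transformed chunk. A toy, self-contained sketch of that pattern follows; every name in it is a hypothetical stand-in, not litellm's or Vertex AI's API:

import asyncio
from typing import AsyncIterator

async def fake_provider_stream() -> AsyncIterator[str]:
    # Stands in for a provider-native async stream, e.g. an async chat send with stream=True
    for piece in ["Hel", "lo ", "wor", "ld"]:
        await asyncio.sleep(0)  # simulate waiting on the network
        yield piece

async def streaming_handler() -> AsyncIterator[str]:
    # Same shape as async_streaming() above: get the provider stream, then
    # re-yield each chunk after transforming it (CustomStreamWrapper plays
    # the transforming role in the real code).
    provider_stream = fake_provider_stream()
    async for chunk in provider_stream:
        yield chunk.upper()

async def main() -> None:
    async for transformed_chunk in streaming_handler():
        print(transformed_chunk, end="")
    print()

asyncio.run(main())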

View file

@@ -1157,7 +1157,7 @@ def completion(
                 acompletion=acompletion
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if "stream" in optional_params and optional_params["stream"] == True and acompletion == False:
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
                 )

View file

@@ -73,7 +73,7 @@ def test_vertex_ai():
     litellm.vertex_project = "hardy-device-386718"
     test_models = random.sample(test_models, 4)
-    test_models = litellm.vertex_language_models # always test gemini-pro
+    test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
             if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
@@ -87,7 +87,7 @@ def test_vertex_ai():
             assert len(response.choices[0].message.content) > 1
         except Exception as e:
             pytest.fail(f"Error occurred: {e}")
-test_vertex_ai()
+# test_vertex_ai()
 def test_vertex_ai_stream():
     load_vertex_ai_credentials()
@@ -120,16 +120,48 @@ def test_vertex_ai_stream():
 @pytest.mark.asyncio
 async def test_async_vertexai_response():
+    import random
     load_vertex_ai_credentials()
-    user_message = "Hello, how are you?"
-    messages = [{"content": user_message, "role": "user"}]
-    try:
-        response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5)
-        # response = await response
-        print(f"response: {response}")
-    except litellm.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"An exception occurred: {e}")
-asyncio.run(test_async_vertexai_response())
+    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+    test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
+    for model in test_models:
+        print(f'model being tested in async call: {model}')
+        try:
+            user_message = "Hello, how are you?"
+            messages = [{"content": user_message, "role": "user"}]
+            response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
+            print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+# asyncio.run(test_async_vertexai_response())
+@pytest.mark.asyncio
+async def test_async_vertexai_streaming_response():
+    import random
+    load_vertex_ai_credentials()
+    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+    test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
+    for model in test_models:
+        try:
+            user_message = "Hello, how are you?"
+            messages = [{"content": user_message, "role": "user"}]
+            response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
+            print(f"response: {response}")
+            complete_response = ""
+            async for chunk in response:
+                print(f"chunk: {chunk}")
+                complete_response += chunk.choices[0].delta.content
+            print(f"complete_response: {complete_response}")
+            assert len(complete_response) > 0
+        except litellm.Timeout as e:
+            pass
+        except Exception as e:
+            print(e)
+            pytest.fail(f"An exception occurred: {e}")
+# asyncio.run(test_async_vertexai_streaming_response())

View file

@@ -19,6 +19,7 @@ import uuid
 import aiohttp
 import logging
 import asyncio, httpx, inspect
+from inspect import iscoroutine
 import copy
 from tokenizers import Tokenizer
 from dataclasses import (
@@ -5769,7 +5770,8 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
-                or self.custom_llm_provider == "huggingface"):
+                or self.custom_llm_provider == "huggingface"
+                or self.custom_llm_provider == "vertex_ai"):
                 async for chunk in self.completion_stream:
                     if chunk == "None" or chunk is None:
                         raise Exception

View file

@@ -294,14 +294,21 @@
         "max_tokens": 2048,
         "input_cost_per_token": 0.000000125,
         "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-chat-models",
+        "litellm_provider": "vertex_ai-code-text-models",
         "mode": "completion"
     },
-    "code-gecko@latest": {
+    "code-gecko@002": {
         "max_tokens": 2048,
         "input_cost_per_token": 0.000000125,
         "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-chat-models",
+        "litellm_provider": "vertex_ai-code-text-models",
+        "mode": "completion"
+    },
+    "code-gecko": {
+        "max_tokens": 2048,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125,
+        "litellm_provider": "vertex_ai-code-text-models",
         "mode": "completion"
     },
     "codechat-bison": {
@@ -340,14 +347,6 @@
         "litellm_provider": "palm",
         "mode": "chat"
     },
-    "gemini-pro": {
-        "max_tokens": 30720,
-        "max_output_tokens": 2048,
-        "input_cost_per_token": 0.0000000625,
-        "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-language-models",
-        "mode": "chat"
-    },
     "palm/chat-bison-001": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,