Merge pull request #1749 from BerriAI/litellm_vertex_ai_model_garden

feat(vertex_ai.py): vertex ai model garden support
Krish Dholakia 2024-02-01 21:52:12 -08:00 committed by GitHub
commit 7fc03bf745
4 changed files with 220 additions and 38 deletions


@@ -75,6 +75,41 @@ class VertexAIConfig:
}
import asyncio
class TextStreamer:
    """
    Fake streaming iterator for Vertex AI Model Garden calls
    """

    def __init__(self, text):
        self.text = text.split()  # let's assume words as a streaming unit
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < len(self.text):
            result = self.text[self.index]
            self.index += 1
            return result
        else:
            raise StopIteration

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index < len(self.text):
            result = self.text[self.index]
            self.index += 1
            return result
        else:
            raise StopAsyncIteration  # once we run out of data to stream, we raise this error
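Illustration: the fake streamer can be consumed both synchronously and asynchronously, one whitespace-delimited word per step (the sample string below is arbitrary):

import asyncio

# synchronous iteration yields one word per __next__ call
print(list(TextStreamer("vertex ai model garden response")))

# asynchronous iteration walks the same words via __anext__
async def consume():
    async for word in TextStreamer("vertex ai model garden response"):
        print(word)

asyncio.run(consume())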
def _get_image_bytes_from_url(image_url: str) -> bytes:
    try:
        response = requests.get(image_url)
@@ -236,12 +271,17 @@ def completion(
        Part,
        GenerationConfig,
    )
    from google.cloud import aiplatform
    from google.protobuf import json_format  # type: ignore
    from google.protobuf.struct_pb2 import Value  # type: ignore
    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
    import google.auth

    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
    creds, _ = google.auth.default(quota_project_id=vertex_project)
    vertexai.init(
        project=vertex_project, location=vertex_location, credentials=creds
    )

    ## Load Config
    config = litellm.VertexAIConfig.get_config()
@@ -275,6 +315,11 @@ def completion(
    request_str = ""
    response_obj = None
    async_client = None
    instances = None
    client_options = {
        "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
    }
    if (
        model in litellm.vertex_language_models
        or model in litellm.vertex_vision_models
@@ -294,39 +339,51 @@ def completion(
        llm_model = CodeGenerationModel.from_pretrained(model)
        mode = "text"
        request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
    elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
        llm_model = CodeChatModel.from_pretrained(model)
        mode = "chat"
        request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
    else:  # assume vertex model garden
        client = aiplatform.gapic.PredictionServiceClient(
            client_options=client_options
        )

        instances = [optional_params]
        instances[0]["prompt"] = prompt
        instances = [
            json_format.ParseDict(instance_dict, Value())
            for instance_dict in instances
        ]
        llm_model = client.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )

        mode = "custom"
        request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"

    if acompletion == True:
        data = {
            "llm_model": llm_model,
            "mode": mode,
            "prompt": prompt,
            "logging_obj": logging_obj,
            "request_str": request_str,
            "model": model,
            "model_response": model_response,
            "encoding": encoding,
            "messages": messages,
            "print_verbose": print_verbose,
            "client_options": client_options,
            "instances": instances,
            "vertex_location": vertex_location,
            "vertex_project": vertex_project,
            **optional_params,
        }
        if optional_params.get("stream", False) is True:
            # async streaming
            return async_streaming(**data)

        return async_completion(**data)

    if mode == "vision":
        print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@@ -471,7 +528,36 @@ def completion(
            },
        )
        completion_response = llm_model.predict(prompt, **optional_params).text
    elif mode == "custom":
        """
        Vertex AI Model Garden
        """
        request_str += (
            f"client.predict(endpoint={llm_model}, instances={instances})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = client.predict(
            endpoint=llm_model,
            instances=instances,
        ).predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if "stream" in optional_params and optional_params["stream"] == True:
            response = TextStreamer(completion_response)
            return response

    ## LOGGING
    logging_obj.post_call(
        input=prompt, api_key=None, original_response=completion_response
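For context, a hedged sketch of how a caller might reach this "custom" branch through litellm, assuming the numeric Model Garden endpoint ID is passed as the model name (as the endpoint_path call above implies); project, region, and endpoint ID are placeholders:

import litellm

litellm.vertex_project = "my-gcp-project"  # placeholder GCP project
litellm.vertex_location = "us-central1"    # placeholder region

response = litellm.completion(
    model="vertex_ai/1234567890123456789",  # Model Garden endpoint ID
    messages=[{"role": "user", "content": "Write a haiku about GPUs"}],
)
print(response.choices[0].message.content)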
@@ -539,6 +625,10 @@ async def async_completion(
    encoding=None,
    messages=None,
    print_verbose=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    **optional_params,
):
    """
@@ -627,7 +717,43 @@ async def async_completion(
        )
        response_obj = await llm_model.predict_async(prompt, **optional_params)
        completion_response = response_obj.text
    elif mode == "custom":
        """
        Vertex AI Model Garden
        """
        from google.cloud import aiplatform

        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options
        )
        llm_model = async_client.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += (
            f"client.predict(endpoint={llm_model}, instances={instances})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response_obj = await async_client.predict(
            endpoint=llm_model,
            instances=instances,
        )
        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]

    ## LOGGING
    logging_obj.post_call(
        input=prompt, api_key=None, original_response=completion_response
@@ -657,14 +783,12 @@ async def async_completion(
    # init prompt tokens
    # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
    prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
    if response_obj is not None and (
        hasattr(response_obj, "usage_metadata")
        and hasattr(response_obj.usage_metadata, "prompt_token_count")
    ):
        prompt_tokens = response_obj.usage_metadata.prompt_token_count
        completion_tokens = response_obj.usage_metadata.candidates_token_count
    else:
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
@@ -693,8 +817,13 @@ async def async_streaming(
    model_response: ModelResponse,
    logging_obj=None,
    request_str=None,
    encoding=None,
    messages=None,
    print_verbose=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    **optional_params,
):
    """
@@ -763,15 +892,47 @@ async def async_streaming(
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)
    elif mode == "custom":
        from google.cloud import aiplatform

        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options
        )
        llm_model = async_client.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response_obj = await async_client.predict(
            endpoint=llm_model,
            instances=instances,
        )
        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if "stream" in optional_params and optional_params["stream"] == True:
            response = TextStreamer(completion_response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper

def embedding():
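Since async_streaming now returns the CustomStreamWrapper rather than yielding chunks itself, the async caller iterates over the returned object. A hedged end-to-end sketch, reusing the placeholder project and Model Garden endpoint ID from the earlier example:

import asyncio
import litellm

litellm.vertex_project = "my-gcp-project"  # placeholder
litellm.vertex_location = "us-central1"    # placeholder

async def main():
    response = await litellm.acompletion(
        model="vertex_ai/1234567890123456789",  # Model Garden endpoint ID
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    # chunks come from the fake TextStreamer, wrapped by CustomStreamWrapper
    async for chunk in response:
        print(chunk)

asyncio.run(main())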


@@ -279,6 +279,9 @@ def test_completion_azure_gpt4_vision():
    except openai.RateLimitError as e:
        print("got a rate liimt error", e)
        pass
    except openai.APIStatusError as e:
        print("got an api status error", e)
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@@ -465,6 +465,7 @@ def test_completion_mistral_api_stream():
def test_completion_deep_infra_stream():
    # deep infra currently includes role in the 2nd chunk
    # waiting for them to make a fix on this
    litellm.set_verbose = True
    try:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},


@@ -6977,6 +6977,21 @@ def exception_type(
            llm_provider="azure",
            response=original_exception.response,
        )
    elif original_exception.status_code == 503:
        exception_mapping_worked = True
        raise ServiceUnavailableError(
            message=f"AzureException - {original_exception.message}",
            model=model,
            llm_provider="azure",
            response=original_exception.response,
        )
    elif original_exception.status_code == 504:  # gateway timeout error
        exception_mapping_worked = True
        raise Timeout(
            message=f"AzureException - {original_exception.message}",
            model=model,
            llm_provider="azure",
        )
    else:
        exception_mapping_worked = True
        raise APIError(
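With the new mappings, Azure 503/504 responses surface as typed litellm exceptions rather than a generic APIError. A hedged sketch of what that means for callers, assuming the classes are exposed via litellm.exceptions; the deployment name is a placeholder:

import litellm
from litellm.exceptions import ServiceUnavailableError, Timeout

try:
    litellm.completion(
        model="azure/my-deployment",  # placeholder Azure deployment
        messages=[{"role": "user", "content": "ping"}],
    )
except ServiceUnavailableError as e:  # mapped from an Azure 503
    print("service unavailable, retry later:", e)
except Timeout as e:  # mapped from an Azure 504 gateway timeout
    print("gateway timed out:", e)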
@@ -8061,6 +8076,7 @@ class CustomStreamWrapper:
    if len(original_chunk.choices) > 0:
        try:
            delta = dict(original_chunk.choices[0].delta)
            print_verbose(f"original delta: {delta}")
            model_response.choices[0].delta = Delta(**delta)
        except Exception as e:
            model_response.choices[0].delta = Delta()
@@ -8069,6 +8085,7 @@ class CustomStreamWrapper:
    model_response.system_fingerprint = (
        original_chunk.system_fingerprint
    )
    print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
    if self.sent_first_chunk == False:
        model_response.choices[0].delta["role"] = "assistant"
        self.sent_first_chunk = True