Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
Merge pull request #1749 from BerriAI/litellm_vertex_ai_model_garden
feat(vertex_ai.py): vertex ai model garden support
This commit is contained in: 7fc03bf745
4 changed files with 220 additions and 38 deletions
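At a high level, the new code path lets a Vertex AI Model Garden endpoint be called through litellm's usual completion interface. A minimal usage sketch (the `vertex_ai/<endpoint-id>` model string and the module-level `litellm.vertex_project` / `litellm.vertex_location` settings follow litellm's usual Vertex AI conventions and are placeholders, not values taken from this diff):

```python
import litellm

# Sketch only: project, region, and endpoint ID below are hypothetical.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

response = litellm.completion(
    model="vertex_ai/1234567890",  # numeric Model Garden endpoint ID (placeholder)
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```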
|
@@ -75,6 +75,41 @@ class VertexAIConfig:
         }


+import asyncio
+
+
+class TextStreamer:
+    """
+    Fake streaming iterator for Vertex AI Model Garden calls
+    """
+
+    def __init__(self, text):
+        self.text = text.split()  # let's assume words as a streaming unit
+        self.index = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index < len(self.text):
+            result = self.text[self.index]
+            self.index += 1
+            return result
+        else:
+            raise StopIteration
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.index < len(self.text):
+            result = self.text[self.index]
+            self.index += 1
+            return result
+        else:
+            raise StopAsyncIteration  # once we run out of data to stream, we raise this error
+
+
 def _get_image_bytes_from_url(image_url: str) -> bytes:
     try:
         response = requests.get(image_url)
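The TextStreamer added above is self-contained, so its behaviour can be sanity-checked directly. A quick sketch, assuming the class from the hunk above is in scope:

```python
import asyncio

# TextStreamer (from the hunk above) yields one whitespace-separated word
# per iteration, both synchronously and asynchronously.
streamer = TextStreamer("fake streaming for model garden")
print(list(streamer))  # ['fake', 'streaming', 'for', 'model', 'garden']

async def consume():
    words = []
    async for word in TextStreamer("hello model garden"):
        words.append(word)
    return words

print(asyncio.run(consume()))  # ['hello', 'model', 'garden']
```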
@@ -236,12 +271,17 @@ def completion(
             Part,
             GenerationConfig,
         )
+        from google.cloud import aiplatform
+        from google.protobuf import json_format  # type: ignore
+        from google.protobuf.struct_pb2 import Value  # type: ignore
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
         import google.auth

         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         creds, _ = google.auth.default(quota_project_id=vertex_project)
-        vertexai.init(project=vertex_project, location=vertex_location, credentials=creds)
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )

         ## Load Config
         config = litellm.VertexAIConfig.get_config()
@@ -275,6 +315,11 @@ def completion(

         request_str = ""
         response_obj = None
+        async_client = None
+        instances = None
+        client_options = {
+            "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
+        }
         if (
             model in litellm.vertex_language_models
             or model in litellm.vertex_vision_models
@@ -294,39 +339,51 @@ def completion(
             llm_model = CodeGenerationModel.from_pretrained(model)
             mode = "text"
             request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
-        else:  # vertex_code_llm_models
+        elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
+        else:  # assume vertex model garden
+            client = aiplatform.gapic.PredictionServiceClient(
+                client_options=client_options
+            )
+
+            instances = [optional_params]
+            instances[0]["prompt"] = prompt
+            instances = [
+                json_format.ParseDict(instance_dict, Value())
+                for instance_dict in instances
+            ]
+            llm_model = client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+
+            mode = "custom"
+            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"

-        if acompletion == True:  # [TODO] expand support to vertex ai chat + text models
+        if acompletion == True:
+            data = {
+                "llm_model": llm_model,
+                "mode": mode,
+                "prompt": prompt,
+                "logging_obj": logging_obj,
+                "request_str": request_str,
+                "model": model,
+                "model_response": model_response,
+                "encoding": encoding,
+                "messages": messages,
+                "print_verbose": print_verbose,
+                "client_options": client_options,
+                "instances": instances,
+                "vertex_location": vertex_location,
+                "vertex_project": vertex_project,
+                **optional_params,
+            }
             if optional_params.get("stream", False) is True:
                 # async streaming
-                return async_streaming(
-                    llm_model=llm_model,
-                    mode=mode,
-                    prompt=prompt,
-                    logging_obj=logging_obj,
-                    request_str=request_str,
-                    model=model,
-                    model_response=model_response,
-                    messages=messages,
-                    print_verbose=print_verbose,
-                    **optional_params,
-                )
-            return async_completion(
-                llm_model=llm_model,
-                mode=mode,
-                prompt=prompt,
-                logging_obj=logging_obj,
-                request_str=request_str,
-                model=model,
-                model_response=model_response,
-                encoding=encoding,
-                messages=messages,
-                print_verbose=print_verbose,
-                **optional_params,
-            )
+                return async_streaming(**data)
+
+            return async_completion(**data)

         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@@ -471,7 +528,36 @@ def completion(
                 },
             )
             completion_response = llm_model.predict(prompt, **optional_params).text
+        elif mode == "custom":
+            """
+            Vertex AI Model Garden
+            """
+            request_str += (
+                f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            )
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=None,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "request_str": request_str,
+                },
+            )
+
+            response = client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            ).predictions
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = TextStreamer(completion_response)
+                return response

         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
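Stripped of the litellm plumbing, the new `mode == "custom"` branch above boils down to a plain Vertex AI GAPIC predict call. A standalone sketch of that underlying call (project, region, endpoint ID, and the `"prompt"` instance key are placeholders for illustration):

```python
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Placeholder values for illustration only.
project = "my-gcp-project"
location = "us-central1"
endpoint_id = "1234567890"

client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
)
endpoint = client.endpoint_path(
    project=project, location=location, endpoint=endpoint_id
)

# Model Garden endpoints take a list of Struct-encoded instances.
instances = [json_format.ParseDict({"prompt": "Hello, world"}, Value())]
response = client.predict(endpoint=endpoint, instances=instances)
print(response.predictions[0])
```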
@@ -539,6 +625,10 @@ async def async_completion(
     encoding=None,
     messages=None,
     print_verbose=None,
+    client_options=None,
+    instances=None,
+    vertex_project=None,
+    vertex_location=None,
     **optional_params,
 ):
     """
@@ -627,7 +717,43 @@ async def async_completion(
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text
+        elif mode == "custom":
+            """
+            Vertex AI Model Garden
+            """
+            from google.cloud import aiplatform
+
+            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            llm_model = async_client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+
+            request_str += (
+                f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            )
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=None,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "request_str": request_str,
+                },
+            )
+
+            response_obj = await async_client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            )
+            response = response_obj.predictions
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]

         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
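The async path does the same thing through the async GAPIC client. A rough standalone sketch, reusing the same placeholder project/region/endpoint as the sync example above:

```python
import asyncio

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def predict_custom_endpoint(prompt):
    # Placeholder project/region/endpoint, as in the sync sketch above.
    project, location, endpoint_id = "my-gcp-project", "us-central1", "1234567890"
    async_client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
    )
    endpoint = async_client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    instances = [json_format.ParseDict({"prompt": prompt}, Value())]
    response = await async_client.predict(endpoint=endpoint, instances=instances)
    return response.predictions[0]

print(asyncio.run(predict_custom_endpoint("Hello, world")))
```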
@@ -657,14 +783,12 @@ async def async_completion(
         # init prompt tokens
         # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-        if response_obj is not None:
-            if hasattr(response_obj, "usage_metadata") and hasattr(
-                response_obj.usage_metadata, "prompt_token_count"
-            ):
-                prompt_tokens = response_obj.usage_metadata.prompt_token_count
-                completion_tokens = (
-                    response_obj.usage_metadata.candidates_token_count
-                )
+        if response_obj is not None and (
+            hasattr(response_obj, "usage_metadata")
+            and hasattr(response_obj.usage_metadata, "prompt_token_count")
+        ):
+            prompt_tokens = response_obj.usage_metadata.prompt_token_count
+            completion_tokens = response_obj.usage_metadata.candidates_token_count
         else:
             prompt_tokens = len(encoding.encode(prompt))
             completion_tokens = len(
@@ -693,8 +817,13 @@ async def async_streaming(
     model_response: ModelResponse,
     logging_obj=None,
     request_str=None,
+    encoding=None,
     messages=None,
     print_verbose=None,
+    client_options=None,
+    instances=None,
+    vertex_project=None,
+    vertex_location=None,
     **optional_params,
 ):
     """
@@ -763,15 +892,47 @@ async def async_streaming(
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)
+    elif mode == "custom":
+        from google.cloud import aiplatform
+
+        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+            client_options=client_options
+        )
+        llm_model = async_client.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
+
+        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+        ## LOGGING
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response_obj = await async_client.predict(
+            endpoint=llm_model,
+            instances=instances,
+        )
+        response = response_obj.predictions
+        completion_response = response[0]
+        if (
+            isinstance(completion_response, str)
+            and "\nOutput:\n" in completion_response
+        ):
+            completion_response = completion_response.split("\nOutput:\n", 1)[1]
+        if "stream" in optional_params and optional_params["stream"] == True:
+            response = TextStreamer(completion_response)

     streamwrapper = CustomStreamWrapper(
         completion_stream=response,
         model=model,
         custom_llm_provider="vertex_ai",
         logging_obj=logging_obj,
     )
-    async for transformed_chunk in streamwrapper:
-        yield transformed_chunk
+
+    return streamwrapper


 def embedding():
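Because `async_streaming` now returns the `CustomStreamWrapper` instead of yielding from it, the caller iterates the wrapper directly. A hedged sketch of what that looks like through `litellm.acompletion` (the model string and project/location settings are placeholders following litellm's usual conventions, not values from this diff):

```python
import asyncio
import litellm

async def stream_model_garden():
    # Hypothetical project, region, and endpoint ID.
    litellm.vertex_project = "my-gcp-project"
    litellm.vertex_location = "us-central1"
    response = await litellm.acompletion(
        model="vertex_ai/1234567890",
        messages=[{"role": "user", "content": "Tell me a joke."}],
        stream=True,
    )
    async for chunk in response:  # CustomStreamWrapper is async-iterable
        print(chunk)

asyncio.run(stream_model_garden())
```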
@@ -279,6 +279,9 @@ def test_completion_azure_gpt4_vision():
     except openai.RateLimitError as e:
         print("got a rate liimt error", e)
         pass
+    except openai.APIStatusError as e:
+        print("got an api status error", e)
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -465,6 +465,7 @@ def test_completion_mistral_api_stream():
 def test_completion_deep_infra_stream():
     # deep infra currently includes role in the 2nd chunk
     # waiting for them to make a fix on this
+    litellm.set_verbose = True
     try:
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -6977,6 +6977,21 @@ def exception_type(
                         llm_provider="azure",
                         response=original_exception.response,
                     )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"AzureException - {original_exception.message}",
+                        model=model,
+                        llm_provider="azure",
+                        response=original_exception.response,
+                    )
+                elif original_exception.status_code == 504:  # gateway timeout error
+                    exception_mapping_worked = True
+                    raise Timeout(
+                        message=f"AzureException - {original_exception.message}",
+                        model=model,
+                        llm_provider="azure",
+                    )
                 else:
                     exception_mapping_worked = True
                     raise APIError(
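On the caller's side, the new Azure 503/504 mappings surface as litellm's own exception types, which retry logic can catch explicitly. A sketch, assuming the exception classes are exposed via `litellm.exceptions` as usual and using a hypothetical Azure deployment name:

```python
import litellm
from litellm import completion

try:
    response = completion(
        model="azure/my-deployment",  # hypothetical Azure deployment name
        messages=[{"role": "user", "content": "ping"}],
    )
except litellm.exceptions.ServiceUnavailableError as e:
    # Mapped from HTTP 503 after this change; a candidate for retry/backoff.
    print("service unavailable:", e)
except litellm.exceptions.Timeout as e:
    # Mapped from HTTP 504 (gateway timeout) after this change.
    print("gateway timeout:", e)
```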
@@ -8061,6 +8076,7 @@ class CustomStreamWrapper:
                 if len(original_chunk.choices) > 0:
                     try:
                         delta = dict(original_chunk.choices[0].delta)
+                        print_verbose(f"original delta: {delta}")
                         model_response.choices[0].delta = Delta(**delta)
                     except Exception as e:
                         model_response.choices[0].delta = Delta()
@@ -8069,6 +8085,7 @@ class CustomStreamWrapper:
                 model_response.system_fingerprint = (
                     original_chunk.system_fingerprint
                 )
+                print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
                 if self.sent_first_chunk == False:
                     model_response.choices[0].delta["role"] = "assistant"
                     self.sent_first_chunk = True