fix(vertex_httpx.py): add sync vertex image gen support

Fixes https://github.com/BerriAI/litellm/issues/4623
Krrish Dholakia 2024-07-09 13:33:54 -07:00
parent 5f279c937c
commit a1986fab60
5 changed files with 151 additions and 26 deletions
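
For context, what this enables: litellm.image_generation can now hit Vertex AI without an event loop. A minimal sketch of the new sync path, mirroring the parameters used by the updated test at the bottom of this diff (the project/location values are the test fixture's own, not required values):

import litellm

# Synchronous Vertex AI image generation (new in this commit).
# Parameter values mirror the updated test below.
response = litellm.image_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="adroit-crow-413218",
    vertex_ai_location="us-central1",
    n=1,
)
print(len(response.data))  # response.data is a list of litellm.ImageObject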


@@ -1188,23 +1188,25 @@ class VertexLLM(BaseLLM):
     def image_generation(
         self,
         prompt: str,
-        vertex_project: str,
-        vertex_location: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        model_response: litellm.ImageResponse,
         model: Optional[
             str
         ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-        client: Optional[AsyncHTTPHandler] = None,
+        client: Optional[Any] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
         logging_obj=None,
-        model_response=None,
         aimg_generation=False,
     ):
-        if aimg_generation == True:
-            response = self.aimage_generation(
+        if aimg_generation is True:
+            return self.aimage_generation(
                 prompt=prompt,
                 vertex_project=vertex_project,
                 vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
                 model=model,
                 client=client,
                 optional_params=optional_params,
@@ -1212,13 +1214,99 @@ class VertexLLM(BaseLLM):
                 logging_obj=logging_obj,
                 model_response=model_response,
             )
-            return response
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = sync_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        """
+        Vertex AI Image generation response example:
+        {
+            "predictions": [
+                {
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
+                    "mimeType": "image/png"
+                },
+                {
+                    "mimeType": "image/png",
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
+                }
+            ]
+        }
+        """
+
+        _json_response = response.json()
+        _predictions = _json_response["predictions"]
+
+        _response_data: List[litellm.ImageObject] = []
+        for _prediction in _predictions:
+            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+            _response_data.append(image_object)
+
+        model_response.data = _response_data
+
+        return model_response
+
     async def aimage_generation(
         self,
         prompt: str,
-        vertex_project: str,
-        vertex_location: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
         model_response: litellm.ImageResponse,
         model: Optional[
             str
@ -1263,7 +1351,9 @@ class VertexLLM(BaseLLM):
} \ } \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
""" """
auth_header, _ = self._ensure_access_token(credentials=None, project_id=None) auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or { optional_params = optional_params or {
"sampleCount": 1 "sampleCount": 1
} # default optional params } # default optional params
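
The sync handler above is a thin wrapper over the Vertex AI REST predict endpoint. For reference, a standalone sketch of the same exchange using only httpx, assuming a bearer token is already in hand (litellm itself obtains one via _ensure_access_token):

import base64
import httpx

def vertex_predict_image(token: str, project: str, location: str, prompt: str) -> bytes:
    # Same endpoint the new sync path builds; "imagegeneration" is the default model.
    url = (
        f"https://{location}-aiplatform.googleapis.com/v1/projects/{project}"
        f"/locations/{location}/publishers/google/models/imagegeneration:predict"
    )
    resp = httpx.post(
        url,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json; charset=utf-8",
        },
        json={"instances": [{"prompt": prompt}], "parameters": {"sampleCount": 1}},
        timeout=600.0,
    )
    resp.raise_for_status()
    # Response shape matches the docstring above: predictions[*].bytesBase64Encoded.
    return base64.b64decode(resp.json()["predictions"][0]["bytesBase64Encoded"])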


@@ -4263,6 +4263,7 @@ def image_generation(
             model_response=model_response,
             vertex_project=vertex_ai_project,
             vertex_location=vertex_ai_location,
+            vertex_credentials=vertex_credentials,
             aimg_generation=aimg_generation,
         )
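
With vertex_credentials threaded through here, callers can target a specific service account rather than relying on ambient application-default credentials. A hedged sketch, assuming vertex_credentials accepts a service-account JSON string as it does on litellm's completion path:

import json
import litellm

# Load a service-account key and pass it explicitly; if omitted, litellm
# falls back to application-default credentials (the old behavior).
with open("service_account.json") as f:
    creds = json.dumps(json.load(f))

response = litellm.image_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="adroit-crow-413218",
    vertex_ai_location="us-central1",
    vertex_credentials=creds,
)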


@@ -1,5 +1,5 @@
 model_list:
-  - model_name: tts
+  - model_name: "*"
     litellm_params:
       model: "openai/*"
   - model_name: gemini-1.5-flash
@@ -18,5 +18,4 @@ model_list:
 general_settings:
   alerting: ["slack"]
   alerting_threshold: 10
-  allowed_ips: ["192.168.1.1"]
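
Aside from the Vertex change, the config swap above turns the first entry into a wildcard route: any requested model name now matches and is forwarded as openai/<model>. Assuming a proxy running locally on port 4000 with a placeholder key, usage looks like:

import openai

# Any model name matches the "*" entry and resolves through openai/*.
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")
resp = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)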


@@ -1,16 +1,19 @@
 #### What this tests ####
 # This tests the the acompletion function #

-import sys, os
-import pytest
+import asyncio
+import logging
+import os
+import sys
 import traceback
-import asyncio, logging
+import pytest

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import litellm
-from litellm import completion, acompletion, acreate
+from litellm import acompletion, acreate, completion

 litellm.num_retries = 3
@@ -42,9 +45,36 @@ def test_async_response_openai():
 async def test_get_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
     try:
         response = await acompletion(
-            model="gpt-3.5-turbo", messages=messages, timeout=5
+            model="gpt-3.5-turbo",
+            messages=messages,
+            tools=tools,
+            parallel_tool_calls=True,
+            timeout=5,
         )
         print(f"response: {response}")
         print(f"response ms: {response._response_ms}")
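
Because the test now passes tools and parallel_tool_calls, the model may answer with tool calls instead of text. A short sketch of inspecting them on the response, using the standard OpenAI-style response shape that litellm mirrors:

# After awaiting acompletion(...) with tools, any tool calls live on the message.
message = response.choices[0].message
if message.tool_calls:
    for call in message.tool_calls:
        print(call.function.name, call.function.arguments)
else:
    print(message.content)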


@@ -190,21 +190,26 @@ async def test_aimage_generation_bedrock_with_optional_params():
         pytest.fail(f"An exception occurred - {str(e)}")


+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_aimage_generation_vertex_ai():
+async def test_aimage_generation_vertex_ai(sync_mode):
     from test_amazing_vertex_completion import load_vertex_ai_credentials

     litellm.set_verbose = True
     load_vertex_ai_credentials()
+    data = {
+        "prompt": "An olympic size swimming pool",
+        "model": "vertex_ai/imagegeneration@006",
+        "vertex_ai_project": "adroit-crow-413218",
+        "vertex_ai_location": "us-central1",
+        "n": 1,
+    }
     try:
-        response = await litellm.aimage_generation(
-            prompt="An olympic size swimming pool",
-            model="vertex_ai/imagegeneration@006",
-            vertex_ai_project="adroit-crow-413218",
-            vertex_ai_location="us-central1",
-            n=1,
-        )
+        if sync_mode:
+            response = litellm.image_generation(**data)
+        else:
+            response = await litellm.aimage_generation(**data)
         assert response.data is not None
         assert len(response.data) > 0