fix(vertex_httpx.py): add sync vertex image gen support

Fixes https://github.com/BerriAI/litellm/issues/4623
Krrish Dholakia 2024-07-09 13:33:54 -07:00
parent 5f279c937c
commit a1986fab60
5 changed files with 151 additions and 26 deletions


@@ -1188,23 +1188,25 @@ class VertexLLM(BaseLLM):
     def image_generation(
         self,
         prompt: str,
-        vertex_project: str,
-        vertex_location: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        model_response: litellm.ImageResponse,
         model: Optional[
             str
         ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-        client: Optional[AsyncHTTPHandler] = None,
+        client: Optional[Any] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
         logging_obj=None,
-        model_response=None,
         aimg_generation=False,
     ):
-        if aimg_generation == True:
-            response = self.aimage_generation(
+        if aimg_generation is True:
+            return self.aimage_generation(
                 prompt=prompt,
                 vertex_project=vertex_project,
                 vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
                 model=model,
                 client=client,
                 optional_params=optional_params,
@@ -1212,13 +1214,99 @@ class VertexLLM(BaseLLM):
                 logging_obj=logging_obj,
                 model_response=model_response,
             )
-            return response
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = sync_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        """
+        Vertex AI Image generation response example:
+        {
+            "predictions": [
+                {
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
+                    "mimeType": "image/png"
+                },
+                {
+                    "mimeType": "image/png",
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
+                }
+            ]
+        }
+        """
+
+        _json_response = response.json()
+        _predictions = _json_response["predictions"]
+
+        _response_data: List[litellm.ImageObject] = []
+        for _prediction in _predictions:
+            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+            _response_data.append(image_object)
+
+        model_response.data = _response_data
+
+        return model_response
+
     async def aimage_generation(
         self,
         prompt: str,
-        vertex_project: str,
-        vertex_location: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        model_response: litellm.ImageResponse,
         model: Optional[
             str
@@ -1263,7 +1351,9 @@ class VertexLLM(BaseLLM):
            } \
            "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
         """
-        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
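
Taken together, this lets callers generate Vertex AI images without an event loop. A minimal usage sketch, reusing the sample model, project, and location that appear in this commit's own tests:

    import litellm

    # Sync path added by this commit; the values below are the sample
    # model/project/location from the commit's test case.
    response = litellm.image_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
        vertex_ai_project="adroit-crow-413218",
        vertex_ai_location="us-central1",
        n=1,
    )
    print(len(response.data))  # one ImageObject per generated image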


@@ -4263,6 +4263,7 @@ def image_generation(
                 model_response=model_response,
                 vertex_project=vertex_ai_project,
                 vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
                 aimg_generation=aimg_generation,
             )
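
This one-line change threads `vertex_credentials` from the public image_generation() wrapper down to the Vertex handler. A hedged sketch of supplying explicit credentials instead of relying on ambient gcloud auth; the assumption that `vertex_credentials` accepts the service-account key JSON as a string follows the pattern litellm documents for its other Vertex endpoints:

    import json

    import litellm

    # Assumption: vertex_credentials takes the service-account key JSON
    # as a string, mirroring litellm's documented Vertex usage.
    with open("service_account.json") as f:
        creds = json.dumps(json.load(f))

    response = litellm.image_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
        vertex_credentials=creds,
        vertex_ai_project="adroit-crow-413218",
        vertex_ai_location="us-central1",
    )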


@@ -1,5 +1,5 @@
 model_list:
-  - model_name: tts
+  - model_name: "*"
     litellm_params:
       model: "openai/*"
   - model_name: gemini-1.5-flash
@@ -19,4 +19,3 @@ model_list:
 general_settings:
   alerting: ["slack"]
   alerting_threshold: 10
-  allowed_ips: ["192.168.1.1"]
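
The config now routes any requested model name through the `"*"` wildcard to a matching `openai/*` deployment, and the `allowed_ips` restriction is dropped. A sketch of what the wildcard enables for a client, assuming a locally running litellm proxy on its default port with the usual OpenAI-compatible API:

    from openai import OpenAI

    # Assumes the proxy is running locally; the key is a placeholder.
    client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

    # Any model name now matches the "*" entry and is forwarded as openai/<name>.
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)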


@@ -1,16 +1,19 @@
 #### What this tests ####
 #    This tests the acompletion function #
-import sys, os
-import pytest
+import asyncio
+import logging
+import os
+import sys
 import traceback
-import asyncio, logging
+
+import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import litellm
-from litellm import completion, acompletion, acreate
+from litellm import acompletion, acreate, completion
 
 litellm.num_retries = 3
@@ -42,9 +45,36 @@ def test_async_response_openai():
 async def test_get_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
     try:
         response = await acompletion(
-            model="gpt-3.5-turbo", messages=messages, timeout=5
+            model="gpt-3.5-turbo",
+            messages=messages,
+            tools=tools,
+            parallel_tool_calls=True,
+            timeout=5,
         )
         print(f"response: {response}")
         print(f"response ms: {response._response_ms}")


@@ -190,21 +190,26 @@ async def test_aimage_generation_bedrock_with_optional_params():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_aimage_generation_vertex_ai():
+async def test_aimage_generation_vertex_ai(sync_mode):
     from test_amazing_vertex_completion import load_vertex_ai_credentials
 
     litellm.set_verbose = True
     load_vertex_ai_credentials()
+    data = {
+        "prompt": "An olympic size swimming pool",
+        "model": "vertex_ai/imagegeneration@006",
+        "vertex_ai_project": "adroit-crow-413218",
+        "vertex_ai_location": "us-central1",
+        "n": 1,
+    }
     try:
-        response = await litellm.aimage_generation(
-            prompt="An olympic size swimming pool",
-            model="vertex_ai/imagegeneration@006",
-            vertex_ai_project="adroit-crow-413218",
-            vertex_ai_location="us-central1",
-            n=1,
-        )
+        if sync_mode:
+            response = litellm.image_generation(**data)
+        else:
+            response = await litellm.aimage_generation(**data)
         assert response.data is not None
         assert len(response.data) > 0
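
Since Vertex returns images base64-encoded (the `bytesBase64Encoded` field shown in the handler's docstring above), each returned `ImageObject` carries the bytes in `b64_json`. A small sketch of persisting the first image; the output filename is arbitrary:

    import base64

    # Decode the base64 payload produced by the new handler and write it out.
    with open("pool.png", "wb") as f:
        f.write(base64.b64decode(response.data[0].b64_json))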