Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 03:34:10 +00:00

feat - working httpx requests vertex ai image gen

This commit is contained in:
parent 064a72f71f
commit 884e2beed6

3 changed files with 188 additions and 1 deletion
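
For context, the new code path is exercised end to end roughly like this (a minimal usage sketch, not part of the diff: the project and location values are placeholders, and working Application Default Credentials are assumed):

import asyncio

import litellm

async def main():
    # assumes `gcloud auth application-default login` (or equivalent ADC setup)
    response = await litellm.aimage_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
        vertex_project="my-gcp-project",  # placeholder project id
        vertex_location="us-central1",  # placeholder region
    )
    print(f"response: {response}")

asyncio.run(main())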

litellm/llms/vertex_httpx.py (new file, 156 lines)

@@ -0,0 +1,156 @@
import os, types
import json
from enum import Enum
import requests  # type: ignore
import time
from typing import Callable, Optional, Union, List
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect  # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM

class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs

class VertexLLM(BaseLLM):
    # class-level import so `Credentials` is available for the annotations below
    from google.auth.credentials import Credentials  # type: ignore[import-untyped]

    def __init__(self) -> None:
        super().__init__()
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self._credentials: Optional[Credentials] = None
        self.project_id: Optional[str] = None

    def load_auth(self) -> tuple[Credentials, str]:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth

        credentials, project_id = google_auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )

        credentials.refresh(Request())

        if not project_id:
            raise ValueError("Could not resolve project_id")

        if not isinstance(project_id, str):
            raise TypeError(
                f"Expected project_id to be a str but got {type(project_id)}"
            )

        return credentials, project_id

    def refresh_auth(self, credentials: Credentials) -> None:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]

        credentials.refresh(Request())

    def _prepare_request(self, request: httpx.Request) -> None:
        if request.headers.get("Authorization"):
            # already authenticated, nothing for us to do
            return

        # only resolve a token when we actually need one
        access_token = self._ensure_access_token()
        request.headers["Authorization"] = f"Bearer {access_token}"
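
    # NOTE: `_prepare_request` matches the shape of an httpx "request" event
    # hook; a hypothetical use (nothing in this diff wires it up) would be:
    #   httpx.Client(event_hooks={"request": [self._prepare_request]})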

    def _ensure_access_token(self) -> str:
        if self.access_token is not None:
            return self.access_token

        if not self._credentials:
            self._credentials, project_id = self.load_auth()
            if not self.project_id:
                self.project_id = project_id
        else:
            self.refresh_auth(self._credentials)

        if not self._credentials.token:
            raise RuntimeError("Could not resolve API token from the environment")

        assert isinstance(self._credentials.token, str)
        return self._credentials.token

    async def aimage_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
        model_response=None,
    ):
        response = None
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            client = AsyncHTTPHandler(**_params)  # type: ignore

        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d '{
            "instances": [
                {
                    "prompt": "a cat"
                }
            ]
        }' \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """

        auth_header = self._ensure_access_token()

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": {"sampleCount": 1},
        }

        response = await client.post(
            url=url,
            headers={
                "Content-Type": "application/json; charset=utf-8",
                "Authorization": f"Bearer {auth_header}",
            },
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise VertexAIError(status_code=response.status_code, message=response.text)

        return model_response
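
For reference, the raw REST call this method issues can be reproduced directly with httpx (a sketch mirroring the curl example in the docstring; PROJECT_ID, the region, and the token are placeholders):

import httpx

token = "..."  # e.g. obtained via VertexLLM()._ensure_access_token()
url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID"
    "/locations/us-central1/publishers/google/models/imagegeneration:predict"
)
response = httpx.post(
    url,
    headers={
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {token}",
    },
    json={"instances": [{"prompt": "a cat"}], "parameters": {"sampleCount": 1}},
)
response.raise_for_status()
print(response.json())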

litellm/main.py

@@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
 from .llms.bedrock_httpx import BedrockLLM
+from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,

@@ -118,6 +119,7 @@ huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
+vertex_chat_completion = VertexLLM()
 ####### COMPLETION ENDPOINTS ################

@@ -3854,6 +3856,35 @@ def image_generation(
                 model_response=model_response,
                 aimg_generation=aimg_generation,
             )
+        elif custom_llm_provider == "vertex_ai":
+            vertex_ai_project = (
+                optional_params.pop("vertex_project", None)
+                or optional_params.pop("vertex_ai_project", None)
+                or litellm.vertex_project
+                or get_secret("VERTEXAI_PROJECT")
+            )
+            vertex_ai_location = (
+                optional_params.pop("vertex_location", None)
+                or optional_params.pop("vertex_ai_location", None)
+                or litellm.vertex_location
+                or get_secret("VERTEXAI_LOCATION")
+            )
+            vertex_credentials = (
+                optional_params.pop("vertex_credentials", None)
+                or optional_params.pop("vertex_ai_credentials", None)
+                or get_secret("VERTEXAI_CREDENTIALS")
+            )
+            model_response = vertex_chat_completion.aimage_generation(  # type: ignore
+                model=model,
+                prompt=prompt,
+                timeout=timeout,
+                logging_obj=litellm_logging_obj,
+                optional_params=optional_params,
+                model_response=model_response,
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+            )
+
         return model_response
     except Exception as e:
         ## Map to OpenAI Exception
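
The branch resolves the project and location from explicit kwargs first, then module-level litellm settings, then environment variables, so configuration can live entirely in the environment (a sketch, not part of the diff; values are placeholders):

import asyncio
import os

import litellm

# resolution order per the branch above: explicit kwarg -> litellm.vertex_* -> env var
os.environ["VERTEXAI_PROJECT"] = "my-gcp-project"  # placeholder
os.environ["VERTEXAI_LOCATION"] = "us-central1"  # placeholder

# with the env vars set, no vertex_* kwargs are needed on the call
response = asyncio.run(
    litellm.aimage_generation(prompt="a cat", model="vertex_ai/imagegeneration")
)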

litellm/tests/test_image_generation.py

@@ -175,7 +175,7 @@ async def test_aimage_generation_bedrock_with_optional_params():
 async def test_aimage_generation_vertex_ai():
     try:
         response = await litellm.aimage_generation(
-            prompt="A cute baby sea otter",
+            prompt="An olympic size swimming pool",
             model="vertex_ai/imagegeneration@006",
         )
         print(f"response: {response}")