# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os
import types
import json
from enum import Enum
import requests  # type: ignore
import time
from typing import Callable, Optional, Union, List, Any, Tuple
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm
import uuid
import httpx  # type: ignore
import inspect
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
    ContentType,
    SystemInstructions,
    PartType,
    RequestBody,
    GenerateContentResponseBody,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
    ChatCompletionUsageBlock,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)


class VertexGeminiConfig:
    def __init__(self) -> None:
        pass

    def supports_system_message(self) -> bool:
        """
        Not all gemini models support system instructions.
        """
        return True
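

# A hedged, commented-out sketch (not part of the library API) of how a caller
# might use supports_system_message() to decide whether to hoist "system"
# messages into Gemini's system_instruction field. `msgs` and the hoisting
# lines are illustrative assumptions, mirroring the logic in
# VertexLLM.completion() below.
#
#   config = VertexGeminiConfig()
#   msgs = [
#       {"role": "system", "content": "You are terse."},
#       {"role": "user", "content": "Hi"},
#   ]
#   if config.supports_system_message():
#       system_parts = [
#           PartType(text=m["content"]) for m in msgs if m["role"] == "system"
#       ]
#       msgs = [m for m in msgs if m["role"] != "system"]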


async def make_call(
    client: Optional[AsyncHTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    if client is None:
        client = AsyncHTTPHandler()  # create a new client if none is provided

    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise VertexAIError(status_code=response.status_code, message=response.text)

    completion_stream = ModelResponseIterator(
        streaming_response=response.aiter_bytes(chunk_size=2056)
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    if client is None:
        client = HTTPHandler()  # create a new client if none is provided

    response = client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise VertexAIError(status_code=response.status_code, message=response.read())

    completion_stream = ModelResponseIterator(
        streaming_response=response.iter_bytes(chunk_size=2056)
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
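

# Illustrative only (commented out so module import stays side-effect free):
# wiring make_sync_call() into a raw streaming read. The url, payload, and
# logging_obj below are placeholder assumptions, not tested values.
#
#   stream = make_sync_call(
#       client=None,
#       api_base="https://us-central1-aiplatform.googleapis.com/...",  # placeholder
#       headers={"Authorization": "Bearer <token>"},
#       data=json.dumps({"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]}),
#       model="gemini-pro",
#       messages=[{"role": "user", "content": "Hi"}],
#       logging_obj=logging_obj,  # assumes a litellm Logging object is in scope
#   )
#   for chunk in stream:
#       print(chunk["text"], end="")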


class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
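

# A minimal, illustrative error-handling sketch (commented out; the call site
# and retry policy are assumptions for demonstration):
#
#   try:
#       resp = VertexLLM().completion(...)  # hypothetical call site
#   except VertexAIError as e:
#       if e.status_code == 408:
#           ...  # e.g. retry on timeout
#       else:
#           raise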


class VertexLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self._credentials: Optional[Any] = None
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

    def _process_response(
        self,
        model: str,
        response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: List,
        print_verbose,
        encoding,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )

        print_verbose(f"raw model_response: {response.text}")

        ## RESPONSE OBJECT
        try:
            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
        except Exception as e:
            raise VertexAIError(
                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
                    response.text, str(e)
                ),
                status_code=422,
            )

        model_response.choices = []  # type: ignore

        ## GET MODEL ##
        model_response.model = model

        ## GET TEXT ##
        for idx, candidate in enumerate(completion_response["candidates"]):
            if candidate.get("content", None) is None:
                continue

            message = litellm.Message(
                content=candidate["content"]["parts"][0]["text"],
                role="assistant",
                logprobs=None,
                function_call=None,
                tool_calls=None,
            )
            choice = litellm.Choices(
                finish_reason=map_finish_reason(
                    candidate.get("finishReason", "stop")
                ),  # map Gemini finish reasons (e.g. "STOP") to OpenAI-style values
                index=candidate.get("index", idx),
                message=message,
                logprobs=None,
                enhancements=None,
            )

            model_response.choices.append(choice)

        ## GET USAGE ##
        usage = litellm.Usage(
            prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
            completion_tokens=completion_response["usageMetadata"][
                "candidatesTokenCount"
            ],
            total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
        )

        setattr(model_response, "usage", usage)

        return model_response
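
    # For reference, the generateContent response shape that _process_response
    # consumes looks roughly like this (illustrative values; only the fields
    # read above are shown):
    #
    #   {
    #       "candidates": [
    #           {
    #               "content": {"parts": [{"text": "Hello!"}], "role": "model"},
    #               "finishReason": "STOP",
    #               "index": 0
    #           }
    #       ],
    #       "usageMetadata": {
    #           "promptTokenCount": 5,
    #           "candidatesTokenCount": 2,
    #           "totalTokenCount": 7
    #       }
    #   }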

    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
        return vertex_region or "us-central1"

    def load_auth(
        self, credentials: Optional[str], project_id: Optional[str]
    ) -> Tuple[Any, str]:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth

        if credentials is not None and isinstance(credentials, str):
            # credentials is a JSON string of a service-account key
            import google.oauth2.service_account

            json_obj = json.loads(credentials)

            creds = google.oauth2.service_account.Credentials.from_service_account_info(
                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )

            if project_id is None:
                project_id = creds.project_id
        else:
            # fall back to Application Default Credentials
            creds, project_id = google_auth.default(
                quota_project_id=project_id,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )

        creds.refresh(Request())

        if not project_id:
            raise ValueError("Could not resolve project_id")

        if not isinstance(project_id, str):
            raise TypeError(
                f"Expected project_id to be a str but got {type(project_id)}"
            )

        return creds, project_id
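
    # Illustrative usage (commented out): passing a service-account key as a
    # JSON string. Every field below is a placeholder, not a real key.
    #
    #   sa_key = json.dumps({
    #       "type": "service_account",
    #       "project_id": "my-project",
    #       "private_key_id": "…",
    #       "private_key": "-----BEGIN PRIVATE KEY-----\n…",
    #       "client_email": "svc@my-project.iam.gserviceaccount.com",
    #       "token_uri": "https://oauth2.googleapis.com/token",
    #   })
    #   creds, project = VertexLLM().load_auth(credentials=sa_key, project_id=None)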

    def refresh_auth(self, credentials: Any) -> None:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]

        credentials.refresh(Request())

    def _ensure_access_token(
        self, credentials: Optional[str], project_id: Optional[str]
    ) -> Tuple[str, str]:
        """
        Returns auth token and project id
        """
        if self.access_token is not None and self.project_id is not None:
            return self.access_token, self.project_id

        if not self._credentials:
            self._credentials, project_id = self.load_auth(
                credentials=credentials, project_id=project_id
            )
            if not self.project_id:
                self.project_id = project_id
        else:
            self.refresh_auth(self._credentials)

            if not self.project_id:
                self.project_id = self._credentials.project_id

        if not self.project_id:
            raise ValueError("Could not resolve project_id")

        if not self._credentials or not self._credentials.token:
            raise RuntimeError("Could not resolve API token from the environment")

        return self._credentials.token, self.project_id
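
    # Illustrative only: _ensure_access_token() caches the loaded credentials
    # on the instance, so a second call takes the refresh path instead of
    # reloading. Values below are assumptions for demonstration.
    #
    #   llm = VertexLLM()
    #   token, project = llm._ensure_access_token(credentials=None, project_id=None)
    #   token2, _ = llm._ensure_access_token(credentials=None, project_id=None)  # cached credentials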

    async def async_streaming(self):
        pass

    async def async_completion(self):
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        litellm_params=None,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials, project_id=vertex_project
        )
        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

        ### SET RUNTIME ENDPOINT ###
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"

        ## TRANSFORMATION ##
        # Separate the system prompt from the rest of the messages
        system_prompt_indices = []
        system_content_blocks: List[PartType] = []
        for idx, message in enumerate(messages):
            if message["role"] == "system":
                _system_content_block = PartType(text=message["content"])
                system_content_blocks.append(_system_content_block)
                system_prompt_indices.append(idx)
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)
        system_instructions = SystemInstructions(parts=system_content_blocks)
        content = _gemini_convert_messages_with_history(messages=messages)

        data = RequestBody(system_instruction=system_instructions, contents=content)
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": url,
                "headers": headers,
            },
        )

        ## SYNC STREAMING CALL ##
        if stream is True:
            streaming_response = CustomStreamWrapper(
                completion_stream=None,
                make_call=partial(
                    make_sync_call,
                    client=None,
                    api_base=url,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                ),
                model=model,
                custom_llm_provider="bedrock",  # reuses CustomStreamWrapper's GenericStreamingChunk parsing path, which is keyed to "bedrock"
                logging_obj=logging_obj,
            )

            return streaming_response

        ## COMPLETION CALL ##
        if client is None or isinstance(client, AsyncHTTPHandler):
            # the sync path needs a sync client; build one if an async client (or none) was passed
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = HTTPHandler(**_params)  # type: ignore

        try:
            response = client.post(url=url, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        return self._process_response(
            model=model,
            response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
            data=data,  # type: ignore
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )
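
    # A hedged, commented-out sketch of driving this endpoint through litellm's
    # public API. Model name and project values are placeholders, and whether a
    # given litellm release routes vertex models through this httpx client
    # depends on the version.
    #
    #   import litellm
    #   resp = litellm.completion(
    #       model="vertex_ai/gemini-pro",
    #       messages=[{"role": "user", "content": "Hi"}],
    #       vertex_project="my-project",
    #       vertex_location="us-central1",
    #   )
    #   print(resp.choices[0].message.content)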

    def image_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
        model_response=None,
        aimg_generation=False,
    ):
        if aimg_generation is True:
            response = self.aimage_generation(
                prompt=prompt,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )
            return response

        # NOTE: a synchronous image-generation path is not implemented here;
        # callers must pass aimg_generation=True (otherwise this returns None).

    async def aimage_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model_response: litellm.ImageResponse,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
    ):
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.async_handler = client  # type: ignore

        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d {
            "instances": [
                {
                    "prompt": "a cat"
                }
            ],
            "parameters": {
                "sampleCount": 1
            }
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }

        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        response = await self.async_handler.post(
            url=url,
            headers={
                "Content-Type": "application/json; charset=utf-8",
                "Authorization": f"Bearer {auth_header}",
            },
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        """
        Vertex AI Image generation response example:
        {
            "predictions": [
                {
                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
                    "mimeType": "image/png"
                },
                {
                    "mimeType": "image/png",
                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
                }
            ]
        }
        """

        _json_response = response.json()
        _predictions = _json_response["predictions"]

        _response_data: List[litellm.ImageObject] = []
        for _prediction in _predictions:
            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
            _response_data.append(image_object)

        model_response.data = _response_data

        return model_response
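
    # Illustrative, commented-out driver for aimage_generation(); the project
    # and location values are placeholder assumptions.
    #
    #   import asyncio
    #
    #   async def _demo():
    #       llm = VertexLLM()
    #       img_resp = litellm.ImageResponse()
    #       result = await llm.aimage_generation(
    #           prompt="a cat",
    #           vertex_project="my-project",
    #           vertex_location="us-central1",
    #           model_response=img_resp,
    #       )
    #       print(result.data[0].b64_json[:32])
    #
    #   asyncio.run(_demo())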


class ModelResponseIterator:
    def __init__(self, streaming_response):
        self.streaming_response = streaming_response
        self.response_iterator = iter(self.streaming_response)

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            processed_chunk = GenerateContentResponseBody(**chunk)  # type: ignore
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None

            gemini_chunk = processed_chunk["candidates"][0]

            if (
                "content" in gemini_chunk
                and "text" in gemini_chunk["content"]["parts"][0]
            ):
                text = gemini_chunk["content"]["parts"][0]["text"]

            if "finishReason" in gemini_chunk:
                finish_reason = map_finish_reason(
                    finish_reason=gemini_chunk["finishReason"]
                )
                is_finished = True

            if "usageMetadata" in processed_chunk:
                usage = ChatCompletionUsageBlock(
                    prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
                    completion_tokens=processed_chunk["usageMetadata"][
                        "candidatesTokenCount"
                    ],
                    total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
                )

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=0,
            )
            return returned_chunk
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self.response_iterator)
            chunk = chunk.decode()  # assumes each chunk is one complete JSON object
            json_chunk = json.loads(chunk)
            return self.chunk_parser(chunk=json_chunk)
        except StopIteration:
            raise
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
            chunk = chunk.decode()  # assumes each chunk is one complete JSON object
            json_chunk = json.loads(chunk)
            return self.chunk_parser(chunk=json_chunk)
        except StopAsyncIteration:
            raise
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e}")
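

# Illustrative, commented-out round trip through ModelResponseIterator using a
# hand-built byte stream; the payload mirrors the generateContent chunk shape
# parsed above and is not a captured API response.
#
#   fake_chunks = [
#       json.dumps(
#           {
#               "candidates": [
#                   {
#                       "content": {"parts": [{"text": "Hello"}], "role": "model"},
#                       "finishReason": "STOP",
#                   }
#               ],
#               "usageMetadata": {
#                   "promptTokenCount": 1,
#                   "candidatesTokenCount": 1,
#                   "totalTokenCount": 2,
#               },
#           }
#       ).encode()
#   ]
#   for parsed in ModelResponseIterator(streaming_response=fake_chunks):
#       print(parsed["text"], parsed["finish_reason"])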