feat(vertex_httpx.py): Moving to call vertex ai via httpx (instead of their sdk). Allows us to support all their api updates.

This commit is contained in:
Krrish Dholakia 2024-06-12 16:47:00 -07:00
parent fb96f07ccb
commit 3b913443fe
8 changed files with 431 additions and 40 deletions

View file

@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
response.choices[0], litellm.utils.Choices
):
for word in self.banned_keywords_list:
self.test_violation(test_str=response.choices[0].message.content)
self.test_violation(test_str=response.choices[0].message.content or "")
async def async_post_call_streaming_hook(
self,

View file

@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
class VertexAIError(Exception):
@ -611,7 +616,7 @@ def completion(
llm_model = None
# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
if acompletion == True:
if acompletion is True:
data = {
"llm_model": llm_model,
"mode": mode,
@ -643,7 +648,7 @@ def completion(
tools = optional_params.pop("tools", None)
content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
if stream == True:
if stream is True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,

View file

@ -9,6 +9,14 @@ import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
ContentType,
SystemInstructions,
PartType,
RequestBody,
GenerateContentResponseBody,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
class VertexAIError(Exception):
@ -33,16 +41,110 @@ class VertexLLM(BaseLLM):
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def load_auth(self) -> Tuple[Any, str]:
def _process_response(
self,
model: str,
response: httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
model_response.choices = []
## GET MODEL ##
model_response.model = model
## GET TEXT ##
for idx, candidate in enumerate(completion_response["candidates"]):
if candidate.get("content", None) is None:
continue
message = litellm.Message(
content=candidate["content"]["parts"][0]["text"],
role="assistant",
logprobs=None,
function_call=None,
tool_calls=None,
)
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=message,
logprobs=None,
enhancements=None,
)
model_response.choices.append(choice)
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
completion_tokens=completion_response["usageMetadata"][
"candidatesTokenCount"
],
total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
)
setattr(model_response, "usage", usage)
return model_response
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
return vertex_region or "us-central1"
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth
credentials, project_id = google_auth.default(
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
json_obj = json.loads(credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
credentials.refresh(Request())
if project_id is None:
project_id = creds.project_id
else:
creds, project_id = google_auth.default(
quota_project_id=project_id,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
creds.refresh(Request())
if not project_id:
raise ValueError("Could not resolve project_id")
@ -52,38 +154,135 @@ class VertexLLM(BaseLLM):
f"Expected project_id to be a str but got {type(project_id)}"
)
return credentials, project_id
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
credentials.refresh(Request())
def _prepare_request(self, request: httpx.Request) -> None:
access_token = self._ensure_access_token()
if request.headers.get("Authorization"):
# already authenticated, nothing for us to do
return
request.headers["Authorization"] = f"Bearer {access_token}"
def _ensure_access_token(self) -> str:
if self.access_token is not None:
return self.access_token
def _ensure_access_token(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[str, str]:
"""
Returns auth token and project id
"""
if self.access_token is not None and self.project_id is not None:
return self.access_token, self.project_id
if not self._credentials:
self._credentials, project_id = self.load_auth()
self._credentials, project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
if not self.project_id:
self.project_id = project_id
else:
self.refresh_auth(self._credentials)
if not self._credentials.token:
if not self.project_id:
self.project_id = self._credentials.project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")
if not self._credentials or not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")
assert isinstance(self._credentials.token, str)
return self._credentials.token
return self._credentials.token, self.project_id
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
stream = optional_params.pop("stream", None)
### SET RUNTIME ENDPOINT ###
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
## TRANSFORMATION ##
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
system_instructions = SystemInstructions(parts=system_content_blocks)
content = _gemini_convert_messages_with_history(messages=messages)
data = RequestBody(system_instruction=system_instructions, contents=content)
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": url,
"headers": headers,
},
)
## COMPLETION CALL ##
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=url, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data, # type: ignore
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def image_generation(
self,
@ -163,7 +362,7 @@ class VertexLLM(BaseLLM):
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header = self._ensure_access_token()
auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params

View file

@ -1893,6 +1893,7 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
new_params = deepcopy(optional_params)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
@ -1910,6 +1911,26 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
elif (
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
model_response = vertex_chat_completion.completion( # type: ignore
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
)
else:
model_response = vertex_ai.completion(
model=model,

View file

@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety(
response.choices[0], litellm.utils.Choices
):
await self.test_violation(
content=response.choices[0].message.content, source="output"
content=response.choices[0].message.content or "", source="output"
)
# async def async_post_call_streaming_hook(

View file

@ -532,6 +532,8 @@ def test_gemini_pro_vision():
# DO Not DELETE this ASSERT
# Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
assert prompt_tokens == 263 # the gemini api returns 263 to us
# assert False
except litellm.RateLimitError as e:
pass
except Exception as e:

View file

@ -9,6 +9,7 @@ from typing_extensions import (
runtime_checkable,
Required,
)
from enum import Enum
class Field(TypedDict):
@ -51,3 +52,161 @@ class PartType(TypedDict, total=False):
class ContentType(TypedDict, total=False):
role: Literal["user", "model"]
parts: Required[List[PartType]]
class SystemInstructions(TypedDict):
parts: Required[List[PartType]]
class Schema(TypedDict, total=False):
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
description: str
enum: List[str]
items: List["Schema"]
properties: "Schema"
required: List[str]
nullable: bool
class FunctionDeclaration(TypedDict, total=False):
name: Required[str]
description: str
parameters: Schema
response: Schema
class FunctionCallingConfig(TypedDict, total=False):
mode: Literal["ANY", "AUTO", "NONE"]
allowed_function_names: List[str]
HarmCategory = Literal[
"HARM_CATEGORY_UNSPECIFIED",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
]
HarmBlockThreshold = Literal[
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
"BLOCK_LOW_AND_ABOVE",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_ONLY_HIGH",
"BLOCK_NONE",
]
HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"]
HarmProbability = Literal[
"HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"
]
HarmSeverity = Literal[
"HARM_SEVERITY_UNSPECIFIED",
"HARM_SEVERITY_NEGLIGIBLE",
"HARM_SEVERITY_LOW",
"HARM_SEVERITY_MEDIUM",
"HARM_SEVERITY_HIGH",
]
class SafetSettingsConfig(TypedDict, total=False):
category: HarmCategory
threshold: HarmBlockThreshold
max_influential_terms: int
method: HarmBlockMethod
class GenerationConfig(TypedDict, total=False):
temperature: float
top_p: float
top_k: float
candidate_count: int
max_output_tokens: int
stop_sequences: List[str]
presence_penalty: float
frequency_penalty: float
response_mime_type: Literal["text/plain", "application/json"]
class RequestBody(TypedDict, total=False):
contents: Required[List[ContentType]]
system_instruction: SystemInstructions
tools: FunctionDeclaration
tool_config: FunctionCallingConfig
safety_settings: SafetSettingsConfig
generation_config: GenerationConfig
class SafetyRatings(TypedDict):
category: HarmCategory
probability: HarmProbability
probabilityScore: int
severity: HarmSeverity
blocked: bool
class Date(TypedDict):
year: int
month: int
date: int
class Citation(TypedDict):
startIndex: int
endIndex: int
uri: str
title: str
license: str
publicationDate: Date
class CitationMetadata(TypedDict):
citations: List[Citation]
class SearchEntryPoint(TypedDict, total=False):
renderedContent: str
sdkBlob: str
class GroundingMetadata(TypedDict, total=False):
webSearchQueries: List[str]
searchEntryPoint: SearchEntryPoint
class Candidates(TypedDict, total=False):
index: int
content: ContentType
finishReason: Literal[
"FINISH_REASON_UNSPECIFIED",
"STOP",
"MAX_TOKENS",
"SAFETY",
"RECITATION",
"OTHER",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
]
safetyRatings: SafetyRatings
citationMetadata: CitationMetadata
groundingMetadata: GroundingMetadata
finishMessage: str
class PromptFeedback(TypedDict):
blockReason: str
safetyRatings: List[SafetyRatings]
blockReasonMessage: str
class UsageMetadata(TypedDict):
promptTokenCount: int
totalTokenCount: int
candidatesTokenCount: int
class GenerateContentResponseBody(TypedDict, total=False):
candidates: Required[List[Candidates]]
promptFeedback: PromptFeedback
usageMetadata: Required[UsageMetadata]

View file

@ -518,15 +518,18 @@ class Choices(OpenAIObject):
self,
finish_reason=None,
index=0,
message=None,
message: Optional[Union[Message, dict]] = None,
logprobs=None,
enhancements=None,
**params,
):
super(Choices, self).__init__(**params)
self.finish_reason = (
map_finish_reason(finish_reason) or "stop"
if finish_reason is not None:
self.finish_reason = map_finish_reason(
finish_reason
) # set finish_reason for all responses
else:
self.finish_reason = "stop"
self.index = index
if message is None:
self.message = Message()
@ -2822,7 +2825,9 @@ class Rules:
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
return True
def post_call_rules(self, input: str, model: str):
def post_call_rules(self, input: Optional[str], model: str) -> bool:
if input is None:
return True
for rule in litellm.post_call_rules:
if callable(rule):
decision = rule(input)
@ -3101,9 +3106,9 @@ def client(original_function):
pass
else:
if isinstance(original_response, ModelResponse):
model_response = original_response["choices"][0]["message"][
"content"
]
model_response = original_response.choices[
0
].message.content
### POST-CALL RULES ###
rules_obj.post_call_rules(input=model_response, model=model)
except Exception as e: