feat(vertex_httpx.py): move to calling Vertex AI via httpx (instead of their SDK). This lets us support all of their API updates.
commit 3b913443fe (parent fb96f07ccb)
8 changed files with 431 additions and 40 deletions
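
At a high level, the new vertex_httpx.py path builds the `:generateContent` REST URL and JSON body by hand and POSTs it with httpx, instead of going through the google-cloud-aiplatform SDK. A minimal standalone sketch of that flow follows; the location, model name and prompt are placeholders, and auth is resolved with google.auth the same way the new load_auth() does below.

```python
# Sketch only: direct REST call to Vertex AI generateContent via httpx.
# Placeholders: location, model, prompt. Requires application default credentials.
import httpx
import google.auth
from google.auth.transport.requests import Request

creds, project_id = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
creds.refresh(Request())  # populates creds.token

location = "us-central1"
model = "gemini-1.0-pro"
url = (
    f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}"
    f"/locations/{location}/publishers/google/models/{model}:generateContent"
)

data = {"contents": [{"role": "user", "parts": [{"text": "Hey, how's it going?"}]}]}
headers = {
    "Content-Type": "application/json; charset=utf-8",
    "Authorization": f"Bearer {creds.token}",
}

response = httpx.post(url, headers=headers, json=data, timeout=60)
response.raise_for_status()
print(response.json()["candidates"][0]["content"]["parts"][0]["text"])
```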
@@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
            response.choices[0], litellm.utils.Choices
        ):
            for word in self.banned_keywords_list:
-                self.test_violation(test_str=response.choices[0].message.content)
+                self.test_violation(test_str=response.choices[0].message.content or "")

    async def async_post_call_streaming_hook(
        self,
@@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
    convert_to_gemini_tool_call_result,
    convert_to_gemini_tool_call_invoke,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)


 class VertexAIError(Exception):
@@ -611,7 +616,7 @@ def completion(
        llm_model = None

    # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
-    if acompletion == True:
+    if acompletion is True:
        data = {
            "llm_model": llm_model,
            "mode": mode,
@@ -643,7 +648,7 @@ def completion(
            tools = optional_params.pop("tools", None)
            content = _gemini_convert_messages_with_history(messages=messages)
            stream = optional_params.pop("stream", False)
-            if stream == True:
+            if stream is True:
                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                logging_obj.pre_call(
                    input=prompt,
@@ -9,6 +9,14 @@ import litellm, uuid
 import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    SystemInstructions,
+    PartType,
+    RequestBody,
+    GenerateContentResponseBody,
+)
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history


 class VertexAIError(Exception):
@@ -33,16 +41,110 @@ class VertexLLM(BaseLLM):
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

-    def load_auth(self) -> Tuple[Any, str]:
+    def _process_response(
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
+        except Exception as e:
+            raise VertexAIError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
+
+        model_response.choices = []
+
+        ## GET MODEL ##
+        model_response.model = model
+        ## GET TEXT ##
+        for idx, candidate in enumerate(completion_response["candidates"]):
+            if candidate.get("content", None) is None:
+                continue
+
+            message = litellm.Message(
+                content=candidate["content"]["parts"][0]["text"],
+                role="assistant",
+                logprobs=None,
+                function_call=None,
+                tool_calls=None,
+            )
+            choice = litellm.Choices(
+                finish_reason=candidate.get("finishReason", "stop"),
+                index=candidate.get("index", idx),
+                message=message,
+                logprobs=None,
+                enhancements=None,
+            )
+
+            model_response.choices.append(choice)
+
+        ## GET USAGE ##
+        usage = litellm.Usage(
+            prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
+            completion_tokens=completion_response["usageMetadata"][
+                "candidatesTokenCount"
+            ],
+            total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
+        )
+
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth

-        credentials, project_id = google_auth.default(
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
+
+            json_obj = json.loads(credentials)
+
+            creds = google.oauth2.service_account.Credentials.from_service_account_info(
+                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )

-        credentials.refresh(Request())
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        creds.refresh(Request())

        if not project_id:
            raise ValueError("Could not resolve project_id")
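
For context, here is a hypothetical example (not part of the diff) of the raw generateContent JSON that the new `_process_response()` consumes, annotated with the fields it maps onto the OpenAI-style ModelResponse; the field names match the GenerateContentResponseBody type added in the types hunk further down.

```python
# Hypothetical generateContent response body (values are made up).
raw = {
    "candidates": [
        {
            "content": {"role": "model", "parts": [{"text": "Hello!"}]},
            "finishReason": "STOP",
            "index": 0,
        }
    ],
    "usageMetadata": {
        "promptTokenCount": 5,
        "candidatesTokenCount": 2,
        "totalTokenCount": 7,
    },
}

# choices[i].message.content <- raw["candidates"][i]["content"]["parts"][0]["text"]
# choices[i].finish_reason   <- raw["candidates"][i].get("finishReason", "stop")
# usage.prompt_tokens        <- raw["usageMetadata"]["promptTokenCount"]
# usage.completion_tokens    <- raw["usageMetadata"]["candidatesTokenCount"]
# usage.total_tokens         <- raw["usageMetadata"]["totalTokenCount"]
```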
@@ -52,38 +154,135 @@ class VertexLLM(BaseLLM):
                f"Expected project_id to be a str but got {type(project_id)}"
            )

-        return credentials, project_id
+        return creds, project_id

    def refresh_auth(self, credentials: Any) -> None:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]

        credentials.refresh(Request())

-    def _prepare_request(self, request: httpx.Request) -> None:
-        access_token = self._ensure_access_token()
-
-        if request.headers.get("Authorization"):
-            # already authenticated, nothing for us to do
-            return
-
-        request.headers["Authorization"] = f"Bearer {access_token}"
-
-    def _ensure_access_token(self) -> str:
-        if self.access_token is not None:
-            return self.access_token
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None and self.project_id is not None:
+            return self.access_token, self.project_id

        if not self._credentials:
-            self._credentials, project_id = self.load_auth()
+            self._credentials, project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
            if not self.project_id:
                self.project_id = project_id
        else:
            self.refresh_auth(self._credentials)

-        if not self._credentials.token:
+        if not self.project_id:
+            self.project_id = self._credentials.project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
            raise RuntimeError("Could not resolve API token from the environment")

-        assert isinstance(self._credentials.token, str)
-        return self._credentials.token
+        return self._credentials.token, self.project_id
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        acompletion: bool,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        litellm_params=None,
+        logger_fn=None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+        stream = optional_params.pop("stream", None)
+
+        ### SET RUNTIME ENDPOINT ###
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+
+        ## TRANSFORMATION ##
+        # Separate system prompt from rest of message
+        system_prompt_indices = []
+        system_content_blocks: List[PartType] = []
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                _system_content_block = PartType(text=message["content"])
+                system_content_blocks.append(_system_content_block)
+                system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+        system_instructions = SystemInstructions(parts=system_content_blocks)
+        content = _gemini_convert_messages_with_history(messages=messages)
+
+        data = RequestBody(system_instruction=system_instructions, contents=content)
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        ## COMPLETION CALL ##
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
+        try:
+            response = client.post(url=url, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        return self._process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,  # type: ignore
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )

    def image_generation(
        self,
@@ -163,7 +362,7 @@ class VertexLLM(BaseLLM):
            } \
            "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
-        auth_header = self._ensure_access_token()
+        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params
@@ -1893,6 +1893,7 @@ def completion(
                or optional_params.pop("vertex_ai_credentials", None)
                or get_secret("VERTEXAI_CREDENTIALS")
            )
+
            new_params = deepcopy(optional_params)
            if "claude-3" in model:
                model_response = vertex_ai_anthropic.completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
model in litellm.vertex_language_models
|
||||||
|
or model in litellm.vertex_vision_models
|
||||||
|
):
|
||||||
|
model_response = vertex_chat_completion.completion( # type: ignore
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=new_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
encoding=encoding,
|
||||||
|
vertex_location=vertex_ai_location,
|
||||||
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
logging_obj=logging,
|
||||||
|
acompletion=acompletion,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_response = vertex_ai.completion(
|
model_response = vertex_ai.completion(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
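
With this branch in place, Gemini text and vision models on Vertex AI are routed through the new httpx-based `VertexLLM.completion()` rather than the aiplatform SDK. A hedged usage sketch, assuming the documented `vertex_project` / `vertex_location` / `vertex_credentials` kwargs and a placeholder project id:

```python
# Usage sketch (placeholder project id). vertex_credentials may be a service-account
# JSON string and is optional when application default credentials are configured.
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-1.0-pro",
    messages=[
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "What is Vertex AI?"},
    ],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
    # vertex_credentials=open("service_account.json").read(),
)
print(response.choices[0].message.content)
```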
@@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
-                content=response.choices[0].message.content, source="output"
+                content=response.choices[0].message.content or "", source="output"
            )

    # async def async_post_call_streaming_hook(
@@ -532,6 +532,8 @@ def test_gemini_pro_vision():
        # DO Not DELETE this ASSERT
        # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
        assert prompt_tokens == 263  # the gemini api returns 263 to us
+
+        # assert False
    except litellm.RateLimitError as e:
        pass
    except Exception as e:
@@ -9,6 +9,7 @@ from typing_extensions import (
    runtime_checkable,
    Required,
 )
+from enum import Enum


 class Field(TypedDict):
@@ -51,3 +52,161 @@ class PartType(TypedDict, total=False):
 class ContentType(TypedDict, total=False):
    role: Literal["user", "model"]
    parts: Required[List[PartType]]
+
+
+class SystemInstructions(TypedDict):
+    parts: Required[List[PartType]]
+
+
+class Schema(TypedDict, total=False):
+    type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
+    description: str
+    enum: List[str]
+    items: List["Schema"]
+    properties: "Schema"
+    required: List[str]
+    nullable: bool
+
+
+class FunctionDeclaration(TypedDict, total=False):
+    name: Required[str]
+    description: str
+    parameters: Schema
+    response: Schema
+
+
+class FunctionCallingConfig(TypedDict, total=False):
+    mode: Literal["ANY", "AUTO", "NONE"]
+    allowed_function_names: List[str]
+
+
+HarmCategory = Literal[
+    "HARM_CATEGORY_UNSPECIFIED",
+    "HARM_CATEGORY_HATE_SPEECH",
+    "HARM_CATEGORY_DANGEROUS_CONTENT",
+    "HARM_CATEGORY_HARASSMENT",
+    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+]
+HarmBlockThreshold = Literal[
+    "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
+    "BLOCK_LOW_AND_ABOVE",
+    "BLOCK_MEDIUM_AND_ABOVE",
+    "BLOCK_ONLY_HIGH",
+    "BLOCK_NONE",
+]
+HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"]
+
+HarmProbability = Literal[
+    "HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"
+]
+
+HarmSeverity = Literal[
+    "HARM_SEVERITY_UNSPECIFIED",
+    "HARM_SEVERITY_NEGLIGIBLE",
+    "HARM_SEVERITY_LOW",
+    "HARM_SEVERITY_MEDIUM",
+    "HARM_SEVERITY_HIGH",
+]
+
+
+class SafetSettingsConfig(TypedDict, total=False):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
+    max_influential_terms: int
+    method: HarmBlockMethod
+
+
+class GenerationConfig(TypedDict, total=False):
+    temperature: float
+    top_p: float
+    top_k: float
+    candidate_count: int
+    max_output_tokens: int
+    stop_sequences: List[str]
+    presence_penalty: float
+    frequency_penalty: float
+    response_mime_type: Literal["text/plain", "application/json"]
+
+
+class RequestBody(TypedDict, total=False):
+    contents: Required[List[ContentType]]
+    system_instruction: SystemInstructions
+    tools: FunctionDeclaration
+    tool_config: FunctionCallingConfig
+    safety_settings: SafetSettingsConfig
+    generation_config: GenerationConfig
+
+
+class SafetyRatings(TypedDict):
+    category: HarmCategory
+    probability: HarmProbability
+    probabilityScore: int
+    severity: HarmSeverity
+    blocked: bool
+
+
+class Date(TypedDict):
+    year: int
+    month: int
+    date: int
+
+
+class Citation(TypedDict):
+    startIndex: int
+    endIndex: int
+    uri: str
+    title: str
+    license: str
+    publicationDate: Date
+
+
+class CitationMetadata(TypedDict):
+    citations: List[Citation]
+
+
+class SearchEntryPoint(TypedDict, total=False):
+    renderedContent: str
+    sdkBlob: str
+
+
+class GroundingMetadata(TypedDict, total=False):
+    webSearchQueries: List[str]
+    searchEntryPoint: SearchEntryPoint
+
+
+class Candidates(TypedDict, total=False):
+    index: int
+    content: ContentType
+    finishReason: Literal[
+        "FINISH_REASON_UNSPECIFIED",
+        "STOP",
+        "MAX_TOKENS",
+        "SAFETY",
+        "RECITATION",
+        "OTHER",
+        "BLOCKLIST",
+        "PROHIBITED_CONTENT",
+        "SPII",
+    ]
+    safetyRatings: SafetyRatings
+    citationMetadata: CitationMetadata
+    groundingMetadata: GroundingMetadata
+    finishMessage: str
+
+
+class PromptFeedback(TypedDict):
+    blockReason: str
+    safetyRatings: List[SafetyRatings]
+    blockReasonMessage: str
+
+
+class UsageMetadata(TypedDict):
+    promptTokenCount: int
+    totalTokenCount: int
+    candidatesTokenCount: int
+
+
+class GenerateContentResponseBody(TypedDict, total=False):
+    candidates: Required[List[Candidates]]
+    promptFeedback: PromptFeedback
+    usageMetadata: Required[UsageMetadata]
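
Since TypedDicts are plain dicts at runtime, the new types double as the literal JSON payload sent to `:generateContent`. A small illustration of assembling a request body with them (prompt and config values are placeholders):

```python
from litellm.types.llms.vertex_ai import (
    ContentType,
    GenerationConfig,
    PartType,
    RequestBody,
    SystemInstructions,
)

# Placeholder prompt/config; this dict is what gets POSTed as the request body.
body = RequestBody(
    system_instruction=SystemInstructions(
        parts=[PartType(text="You are a helpful assistant.")]
    ),
    contents=[ContentType(role="user", parts=[PartType(text="Hi there")])],
    generation_config=GenerationConfig(temperature=0.2, max_output_tokens=256),
)
```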
@@ -518,15 +518,18 @@ class Choices(OpenAIObject):
        self,
        finish_reason=None,
        index=0,
-        message=None,
+        message: Optional[Union[Message, dict]] = None,
        logprobs=None,
        enhancements=None,
        **params,
    ):
        super(Choices, self).__init__(**params)
-        self.finish_reason = (
-            map_finish_reason(finish_reason) or "stop"
-        )  # set finish_reason for all responses
+        if finish_reason is not None:
+            self.finish_reason = map_finish_reason(
+                finish_reason
+            )  # set finish_reason for all responses
+        else:
+            self.finish_reason = "stop"
        self.index = index
        if message is None:
            self.message = Message()
@@ -2822,7 +2825,9 @@ class Rules:
            raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model)  # type: ignore
        return True

-    def post_call_rules(self, input: str, model: str):
+    def post_call_rules(self, input: Optional[str], model: str) -> bool:
+        if input is None:
+            return True
        for rule in litellm.post_call_rules:
            if callable(rule):
                decision = rule(input)
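
The `Optional[str]` change above matters because post-call rules receive the response text, which can be None (for example when the model returns only tool calls); the new guard skips the rules instead of raising. A hedged example of registering such a rule:

```python
import litellm

def no_error_words(output: str) -> bool:
    # Returning False makes Rules.post_call_rules raise APIResponseValidationError.
    return "internal error" not in output.lower()

litellm.post_call_rules = [no_error_words]
```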
@@ -3101,9 +3106,9 @@ def client(original_function):
                pass
            else:
                if isinstance(original_response, ModelResponse):
-                    model_response = original_response["choices"][0]["message"][
-                        "content"
-                    ]
+                    model_response = original_response.choices[
+                        0
+                    ].message.content
                    ### POST-CALL RULES ###
                    rules_obj.post_call_rules(input=model_response, model=model)
        except Exception as e: