forked from phoenix/litellm-mirror

feat(vertex_httpx.py): move to calling Vertex AI via httpx (instead of their SDK). This lets us support all of their API updates.

parent fb96f07ccb
commit 3b913443fe

8 changed files with 431 additions and 40 deletions
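For orientation, a minimal sketch of the REST call this commit switches to, outside of litellm. The endpoint format matches the `url` built in `VertexLLM.completion` below; the project, location, model name, and the `ACCESS_TOKEN` environment variable are placeholder assumptions, not values from the commit.

    # Minimal sketch: call the Vertex AI generateContent REST endpoint with httpx.
    # PROJECT/LOCATION/MODEL and the ACCESS_TOKEN env var are illustrative placeholders.
    import os
    import httpx

    PROJECT = "my-gcp-project"   # assumption: replace with a real project id
    LOCATION = "us-central1"
    MODEL = "gemini-1.0-pro"

    url = (
        f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT}"
        f"/locations/{LOCATION}/publishers/google/models/{MODEL}:generateContent"
    )
    body = {"contents": [{"role": "user", "parts": [{"text": "Say hi"}]}]}
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {os.environ['ACCESS_TOKEN']}",  # e.g. from `gcloud auth print-access-token`
    }

    response = httpx.post(url, headers=headers, json=body)
    response.raise_for_status()
    print(response.json()["candidates"][0]["content"]["parts"][0]["text"])
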
@@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
             response.choices[0], litellm.utils.Choices
         ):
             for word in self.banned_keywords_list:
-                self.test_violation(test_str=response.choices[0].message.content)
+                self.test_violation(test_str=response.choices[0].message.content or "")

     async def async_post_call_streaming_hook(
         self,
@@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
     convert_to_gemini_tool_call_result,
     convert_to_gemini_tool_call_invoke,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)


 class VertexAIError(Exception):
@@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType:
     # GCS URIs
     if "gs://" in image_url:
         # Figure out file type
-        extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
-        extension = extension_with_dot[1:]  # Ex: "png"
+        extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
+        extension = extension_with_dot[1:]  # Ex: "png"

         file_type = get_file_type_from_extension(extension)

         # Validate the file type is supported by Gemini
         if not is_gemini_1_5_accepted_file_type(file_type):
             raise Exception(f"File type not supported by gemini - {file_type}")

         mime_type = get_file_mime_type_for_file_type(file_type)
         file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
@@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         image = _load_image_from_url(image_url)
         _blob = BlobType(data=image.data, mime_type=image._mime_type)
         return PartType(inline_data=_blob)

     # Base64 encoding
     elif "base64" in image_url:
         import base64, re
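The two hunks above cover how `_process_gemini_image` builds a Gemini Part from either a GCS URI or a base64 data URL. A rough sketch of that branching with plain dicts follows; the mimetypes-based lookup and the data-URL parsing are simplifications of litellm's `get_file_type_from_extension` / `get_file_mime_type_for_file_type` helpers and its base64 handling, not the actual implementation.

    # Rough sketch of the branching in _process_gemini_image, using plain dicts instead
    # of litellm's PartType/FileDataType/BlobType.
    import base64
    import mimetypes
    import os


    def process_image(image_url: str) -> dict:
        if "gs://" in image_url:
            # GCS URI: pass the file by reference with its mime type
            extension = os.path.splitext(image_url)[-1]  # e.g. ".png"
            mime_type = mimetypes.types_map.get(extension, "application/octet-stream")
            return {"file_data": {"mime_type": mime_type, "file_uri": image_url}}
        elif image_url.startswith("data:") and "base64," in image_url:
            # Base64 data URL: send the bytes inline
            header, encoded = image_url.split("base64,", 1)
            mime_type = header.split(":", 1)[1].rstrip(";") or "image/jpeg"
            return {"inline_data": {"mime_type": mime_type, "data": base64.b64decode(encoded)}}
        raise ValueError(f"Unsupported image url: {image_url}")


    print(process_image("gs://my-bucket/cat.png"))  # hypothetical bucket/object
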
@@ -611,7 +616,7 @@ def completion(
     llm_model = None

     # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
-    if acompletion == True:
+    if acompletion is True:
         data = {
             "llm_model": llm_model,
             "mode": mode,
@@ -643,7 +648,7 @@ def completion(
             tools = optional_params.pop("tools", None)
             content = _gemini_convert_messages_with_history(messages=messages)
             stream = optional_params.pop("stream", False)
-            if stream == True:
+            if stream is True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
@@ -9,6 +9,14 @@ import litellm, uuid
 import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    SystemInstructions,
+    PartType,
+    RequestBody,
+    GenerateContentResponseBody,
+)
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history


 class VertexAIError(Exception):
@@ -33,16 +41,110 @@ class VertexLLM(BaseLLM):
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None

-    def load_auth(self) -> Tuple[Any, str]:
+    def _process_response(
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
+        except Exception as e:
+            raise VertexAIError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
+
+        model_response.choices = []
+
+        ## GET MODEL ##
+        model_response.model = model
+        ## GET TEXT ##
+        for idx, candidate in enumerate(completion_response["candidates"]):
+            if candidate.get("content", None) is None:
+                continue
+
+            message = litellm.Message(
+                content=candidate["content"]["parts"][0]["text"],
+                role="assistant",
+                logprobs=None,
+                function_call=None,
+                tool_calls=None,
+            )
+            choice = litellm.Choices(
+                finish_reason=candidate.get("finishReason", "stop"),
+                index=candidate.get("index", idx),
+                message=message,
+                logprobs=None,
+                enhancements=None,
+            )
+
+            model_response.choices.append(choice)
+
+        ## GET USAGE ##
+        usage = litellm.Usage(
+            prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
+            completion_tokens=completion_response["usageMetadata"][
+                "candidatesTokenCount"
+            ],
+            total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
+        )
+
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
         from google.auth.transport.requests import Request  # type: ignore[import-untyped]
         from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         import google.auth as google_auth

-        credentials, project_id = google_auth.default(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"],
-        )
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
+
-        credentials.refresh(Request())
+            json_obj = json.loads(credentials)
+
+            creds = google.oauth2.service_account.Credentials.from_service_account_info(
+                json_obj,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        creds.refresh(Request())

         if not project_id:
             raise ValueError("Could not resolve project_id")
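For readers new to the Gemini REST format, this is the shape of payload `_process_response` consumes and the fields it maps onto the litellm response (candidates -> choices, usageMetadata -> usage). The payload values below are made up for illustration.

    # Illustrative only: a hypothetical generateContent response body and the fields read from it.
    sample_response = {
        "candidates": [
            {
                "content": {"role": "model", "parts": [{"text": "Hello there!"}]},
                "finishReason": "STOP",
                "index": 0,
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 5,
            "candidatesTokenCount": 3,
            "totalTokenCount": 8,
        },
    }

    for idx, candidate in enumerate(sample_response["candidates"]):
        if candidate.get("content") is None:
            continue
        text = candidate["content"]["parts"][0]["text"]        # -> choices[idx].message.content
        finish_reason = candidate.get("finishReason", "stop")   # -> choices[idx].finish_reason
        print(idx, finish_reason, text)

    usage = sample_response["usageMetadata"]
    print(usage["promptTokenCount"], usage["candidatesTokenCount"], usage["totalTokenCount"])
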
@@ -52,38 +154,135 @@ class VertexLLM(BaseLLM):
                 f"Expected project_id to be a str but got {type(project_id)}"
             )

-        return credentials, project_id
+        return creds, project_id

     def refresh_auth(self, credentials: Any) -> None:
         from google.auth.transport.requests import Request  # type: ignore[import-untyped]

         credentials.refresh(Request())

     def _prepare_request(self, request: httpx.Request) -> None:
         access_token = self._ensure_access_token()

         if request.headers.get("Authorization"):
             # already authenticated, nothing for us to do
             return

         request.headers["Authorization"] = f"Bearer {access_token}"

-    def _ensure_access_token(self) -> str:
-        if self.access_token is not None:
-            return self.access_token
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None and self.project_id is not None:
+            return self.access_token, self.project_id

         if not self._credentials:
-            self._credentials, project_id = self.load_auth()
+            self._credentials, project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
             if not self.project_id:
                 self.project_id = project_id
         else:
             self.refresh_auth(self._credentials)

-        if not self._credentials.token:
+        if not self.project_id:
+            self.project_id = self._credentials.project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
             raise RuntimeError("Could not resolve API token from the environment")

         assert isinstance(self._credentials.token, str)
-        return self._credentials.token
+        return self._credentials.token, self.project_id
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        acompletion: bool,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        litellm_params=None,
+        logger_fn=None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+        stream = optional_params.pop("stream", None)
+
+        ### SET RUNTIME ENDPOINT ###
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+
+        ## TRANSFORMATION ##
+        # Separate system prompt from rest of message
+        system_prompt_indices = []
+        system_content_blocks: List[PartType] = []
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                _system_content_block = PartType(text=message["content"])
+                system_content_blocks.append(_system_content_block)
+                system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+        system_instructions = SystemInstructions(parts=system_content_blocks)
+        content = _gemini_convert_messages_with_history(messages=messages)
+
+        data = RequestBody(system_instruction=system_instructions, contents=content)
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        ## COMPLETION CALL ##
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
+        try:
+            response = client.post(url=url, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        return self._process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,  # type: ignore
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )

     def image_generation(
         self,
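A sketch of the "## TRANSFORMATION ##" step above with plain dicts: system messages are pulled into `system_instruction`, the rest become `contents`. The user/assistant role mapping here is a simplification of `_gemini_convert_messages_with_history`, not the real converter.

    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]

    # system messages -> system_instruction.parts
    system_parts = [{"text": m["content"]} for m in messages if m["role"] == "system"]

    # remaining messages -> contents, with OpenAI "assistant" mapped to Gemini "model"
    contents = [
        {"role": "model" if m["role"] == "assistant" else "user", "parts": [{"text": m["content"]}]}
        for m in messages
        if m["role"] != "system"
    ]

    request_body = {"contents": contents}
    if system_parts:
        request_body["system_instruction"] = {"parts": system_parts}

    print(request_body)
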
@@ -163,7 +362,7 @@ class VertexLLM(BaseLLM):
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
-        auth_header = self._ensure_access_token()
+        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
@@ -1893,6 +1893,7 @@ def completion(
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret("VERTEXAI_CREDENTIALS")
         )

+        new_params = deepcopy(optional_params)
         if "claude-3" in model:
             model_response = vertex_ai_anthropic.completion(
@@ -1910,6 +1911,26 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion,
             )
+        elif (
+            model in litellm.vertex_language_models
+            or model in litellm.vertex_vision_models
+        ):
+            model_response = vertex_chat_completion.completion(  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=new_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                vertex_location=vertex_ai_location,
+                vertex_project=vertex_ai_project,
+                vertex_credentials=vertex_credentials,
+                logging_obj=logging,
+                acompletion=acompletion,
+                timeout=timeout,
+            )
         else:
             model_response = vertex_ai.completion(
                 model=model,
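Hypothetical usage once Gemini models route through the new httpx-based `vertex_chat_completion`: a standard litellm call, with the project and location given as placeholders (they can also come from VERTEXAI_PROJECT / VERTEXAI_LOCATION or the vertex_ai kwargs picked up in the surrounding code).

    import litellm

    litellm.vertex_project = "my-gcp-project"   # assumption: replace with a real project id
    litellm.vertex_location = "us-central1"

    response = litellm.completion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "Write a one-line haiku about HTTP."}],
    )
    print(response.choices[0].message.content)
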
@@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety(
             response.choices[0], litellm.utils.Choices
         ):
             await self.test_violation(
-                content=response.choices[0].message.content, source="output"
+                content=response.choices[0].message.content or "", source="output"
             )

     # async def async_post_call_streaming_hook(
@@ -532,6 +532,8 @@ def test_gemini_pro_vision():
+        # DO Not DELETE this ASSERT
+        # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
         assert prompt_tokens == 263  # the gemini api returns 263 to us

         # assert False
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
@@ -9,6 +9,7 @@ from typing_extensions import (
     runtime_checkable,
+    Required,
 )
 from enum import Enum


 class Field(TypedDict):
@@ -51,3 +52,161 @@ class PartType(TypedDict, total=False):
 class ContentType(TypedDict, total=False):
     role: Literal["user", "model"]
     parts: Required[List[PartType]]
+
+
+class SystemInstructions(TypedDict):
+    parts: Required[List[PartType]]
+
+
+class Schema(TypedDict, total=False):
+    type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
+    description: str
+    enum: List[str]
+    items: List["Schema"]
+    properties: "Schema"
+    required: List[str]
+    nullable: bool
+
+
+class FunctionDeclaration(TypedDict, total=False):
+    name: Required[str]
+    description: str
+    parameters: Schema
+    response: Schema
+
+
+class FunctionCallingConfig(TypedDict, total=False):
+    mode: Literal["ANY", "AUTO", "NONE"]
+    allowed_function_names: List[str]
+
+
+HarmCategory = Literal[
+    "HARM_CATEGORY_UNSPECIFIED",
+    "HARM_CATEGORY_HATE_SPEECH",
+    "HARM_CATEGORY_DANGEROUS_CONTENT",
+    "HARM_CATEGORY_HARASSMENT",
+    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+]
+HarmBlockThreshold = Literal[
+    "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
+    "BLOCK_LOW_AND_ABOVE",
+    "BLOCK_MEDIUM_AND_ABOVE",
+    "BLOCK_ONLY_HIGH",
+    "BLOCK_NONE",
+]
+HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"]
+
+HarmProbability = Literal[
+    "HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"
+]
+
+HarmSeverity = Literal[
+    "HARM_SEVERITY_UNSPECIFIED",
+    "HARM_SEVERITY_NEGLIGIBLE",
+    "HARM_SEVERITY_LOW",
+    "HARM_SEVERITY_MEDIUM",
+    "HARM_SEVERITY_HIGH",
+]
+
+
+class SafetSettingsConfig(TypedDict, total=False):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
+    max_influential_terms: int
+    method: HarmBlockMethod
+
+
+class GenerationConfig(TypedDict, total=False):
+    temperature: float
+    top_p: float
+    top_k: float
+    candidate_count: int
+    max_output_tokens: int
+    stop_sequences: List[str]
+    presence_penalty: float
+    frequency_penalty: float
+    response_mime_type: Literal["text/plain", "application/json"]
+
+
+class RequestBody(TypedDict, total=False):
+    contents: Required[List[ContentType]]
+    system_instruction: SystemInstructions
+    tools: FunctionDeclaration
+    tool_config: FunctionCallingConfig
+    safety_settings: SafetSettingsConfig
+    generation_config: GenerationConfig
+
+
+class SafetyRatings(TypedDict):
+    category: HarmCategory
+    probability: HarmProbability
+    probabilityScore: int
+    severity: HarmSeverity
+    blocked: bool
+
+
+class Date(TypedDict):
+    year: int
+    month: int
+    date: int
+
+
+class Citation(TypedDict):
+    startIndex: int
+    endIndex: int
+    uri: str
+    title: str
+    license: str
+    publicationDate: Date
+
+
+class CitationMetadata(TypedDict):
+    citations: List[Citation]
+
+
+class SearchEntryPoint(TypedDict, total=False):
+    renderedContent: str
+    sdkBlob: str
+
+
+class GroundingMetadata(TypedDict, total=False):
+    webSearchQueries: List[str]
+    searchEntryPoint: SearchEntryPoint
+
+
+class Candidates(TypedDict, total=False):
+    index: int
+    content: ContentType
+    finishReason: Literal[
+        "FINISH_REASON_UNSPECIFIED",
+        "STOP",
+        "MAX_TOKENS",
+        "SAFETY",
+        "RECITATION",
+        "OTHER",
+        "BLOCKLIST",
+        "PROHIBITED_CONTENT",
+        "SPII",
+    ]
+    safetyRatings: SafetyRatings
+    citationMetadata: CitationMetadata
+    groundingMetadata: GroundingMetadata
+    finishMessage: str
+
+
+class PromptFeedback(TypedDict):
+    blockReason: str
+    safetyRatings: List[SafetyRatings]
+    blockReasonMessage: str
+
+
+class UsageMetadata(TypedDict):
+    promptTokenCount: int
+    totalTokenCount: int
+    candidatesTokenCount: int
+
+
+class GenerateContentResponseBody(TypedDict, total=False):
+    candidates: Required[List[Candidates]]
+    promptFeedback: PromptFeedback
+    usageMetadata: Required[UsageMetadata]
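Since TypedDicts are plain dicts at runtime, building a request with the types added above is just a shape check for readers and type-checkers. The names below come from this commit's litellm.types.llms.vertex_ai module; the values are illustrative.

    from litellm.types.llms.vertex_ai import (
        ContentType,
        GenerationConfig,
        PartType,
        RequestBody,
        SystemInstructions,
    )

    # Assemble a typed generateContent request body
    request: RequestBody = RequestBody(
        contents=[ContentType(role="user", parts=[PartType(text="Ping?")])],
        system_instruction=SystemInstructions(parts=[PartType(text="Reply with one word.")]),
        generation_config=GenerationConfig(temperature=0.2, max_output_tokens=16),
    )
    print(request["contents"][0]["parts"][0]["text"])
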
@@ -518,15 +518,18 @@ class Choices(OpenAIObject):
         self,
         finish_reason=None,
         index=0,
-        message=None,
+        message: Optional[Union[Message, dict]] = None,
         logprobs=None,
         enhancements=None,
         **params,
     ):
         super(Choices, self).__init__(**params)
-        self.finish_reason = (
-            map_finish_reason(finish_reason) or "stop"
-        )  # set finish_reason for all responses
+        if finish_reason is not None:
+            self.finish_reason = map_finish_reason(
+                finish_reason
+            )  # set finish_reason for all responses
+        else:
+            self.finish_reason = "stop"
         self.index = index
         if message is None:
             self.message = Message()
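With the change above, `litellm.Choices` falls back to "stop" when no finish_reason is given, and the message parameter is typed to accept a Message or a plain dict. A small hedged check, assuming the public litellm.Choices / litellm.Message constructors:

    import litellm

    choice = litellm.Choices(
        message=litellm.Message(content="hi", role="assistant"),
    )
    print(choice.finish_reason)  # expected: "stop"
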
@@ -2822,7 +2825,9 @@ class Rules:
             raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model)  # type: ignore
         return True

-    def post_call_rules(self, input: str, model: str):
+    def post_call_rules(self, input: Optional[str], model: str) -> bool:
+        if input is None:
+            return True
         for rule in litellm.post_call_rules:
             if callable(rule):
                 decision = rule(input)
@@ -3101,9 +3106,9 @@ def client(original_function):
                     pass
                 else:
                     if isinstance(original_response, ModelResponse):
-                        model_response = original_response["choices"][0]["message"][
-                            "content"
-                        ]
+                        model_response = original_response.choices[
+                            0
+                        ].message.content
                         ### POST-CALL RULES ###
                         rules_obj.post_call_rules(input=model_response, model=model)
             except Exception as e:
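For context, post-call rules are user-registered callables that receive the model's output text; the two hunks above make sure they are skipped (rather than crashing) when the content is None, e.g. for tool-call responses. A hedged sketch of registering one, assuming only the litellm.post_call_rules list shown in the diff:

    import litellm

    def no_empty_answers(output: str) -> bool:
        # Return True to accept the response; exact return-value handling is litellm's.
        return len(output.strip()) > 0

    litellm.post_call_rules = [no_empty_answers]
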