Merge pull request #4478 from BerriAI/litellm_support_response_schema_param_vertex_ai_old
feat(vertex_httpx.py): support the 'response_schema' param for older vertex ai models
This commit is contained in:
commit 58d0330cd7
14 changed files with 444 additions and 171 deletions
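Taken together, these changes let a caller pass a JSON schema even to Vertex AI models without native `response_schema` support: the schema is injected into the prompt for older models, and the returned JSON can optionally be validated against it. A minimal sketch of the call pattern the new test exercises; the schema shown here is illustrative (the test file defines its own `response_schema` elsewhere):

```python
import litellm

# Illustrative schema; any JSON schema dict can be passed.
response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

resp = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-flash",  # per the test matrix, no native response_schema support
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": response_schema,
        "enforce_validation": True,  # raise JSONSchemaValidationError if the output doesn't match
    },
)
print(resp.choices[0].message.content)
```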
@@ -66,7 +66,7 @@ jobs:
            pip install "pydantic==2.7.1"
            pip install "diskcache==5.6.1"
            pip install "Pillow==10.3.0"
            pip install "ijson==3.2.3"
            pip install "jsonschema==4.22.0"
      - save_cache:
          paths:
            - ./venv
@@ -128,7 +128,7 @@ jobs:
            pip install jinja2
            pip install tokenizers
            pip install openai
            pip install ijson
            pip install jsonschema
      - run:
          name: Run tests
          command: |
@@ -183,7 +183,7 @@ jobs:
            pip install numpydoc
            pip install prisma
            pip install fastapi
            pip install ijson
            pip install jsonschema
            pip install "httpx==0.24.1"
            pip install "gunicorn==21.2.0"
            pip install "anyio==3.7.1"
@@ -749,6 +749,7 @@ from .utils import (
    create_pretrained_tokenizer,
    create_tokenizer,
    supports_function_calling,
    supports_response_schema,
    supports_parallel_function_calling,
    supports_vision,
    supports_system_messages,
@@ -852,6 +853,7 @@ from .exceptions import (
    APIResponseValidationError,
    UnprocessableEntityError,
    InternalServerError,
    JSONSchemaValidationError,
    LITELLM_EXCEPTION_TYPES,
)
from .budget_manager import BudgetManager
@@ -551,7 +551,7 @@ class APIError(openai.APIError):  # type: ignore
        message,
        llm_provider,
        model,
        request: httpx.Request,
        request: Optional[httpx.Request] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
@@ -563,6 +563,8 @@ class APIError(openai.APIError):  # type: ignore
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        if request is None:
            request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        super().__init__(self.message, request=request, body=None)  # type: ignore

    def __str__(self):
@@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError):  # type: ignore
        self.llm_provider = "openai"


class JSONSchemaValidationError(APIError):
    def __init__(
        self, model: str, llm_provider: str, raw_response: str, schema: str
    ) -> None:
        self.raw_response = raw_response
        self.schema = schema
        self.model = model
        message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
            model, raw_response, schema
        )
        self.message = message
        super().__init__(
            model=model, message=message, llm_provider=llm_provider, status_code=500
        )


LITELLM_EXCEPTION_TYPES = [
    AuthenticationError,
    NotFoundError,
@@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
    APIResponseValidationError,
    OpenAIError,
    InternalServerError,
    JSONSchemaValidationError,
]
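A short sketch of catching the new exception; the attributes referenced (`message`, `raw_response`, `schema`) are the ones set in the class above:

```python
import litellm

try:
    litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": {"type": "array"},  # illustrative schema
            "enforce_validation": True,
        },
    )
except litellm.JSONSchemaValidationError as e:
    # The raw model output and the schema it failed against are kept on the error.
    print(e.message)
    print(e.raw_response)
    print(e.schema)
```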
litellm/litellm_core_utils/json_validation_rule.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import json


def validate_schema(schema: dict, response: str):
    """
    Validate if the returned json response follows the schema.

    Params:
    - schema - dict: JSON schema
    - response - str: Received json response as string.
    """
    from jsonschema import ValidationError, validate

    from litellm import JSONSchemaValidationError

    response_dict = json.loads(response)

    try:
        validate(response_dict, schema=schema)
    except ValidationError:
        raise JSONSchemaValidationError(
            model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
        )
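A quick sketch of how this helper behaves on its own (schema and responses are illustrative):

```python
from litellm.litellm_core_utils.json_validation_rule import validate_schema

schema = {
    "type": "object",
    "properties": {"recipe_name": {"type": "string"}},
    "required": ["recipe_name"],
}

# Matches the schema: returns without raising.
validate_schema(schema=schema, response='{"recipe_name": "Sugar Cookies"}')

# Missing required key: raises litellm.JSONSchemaValidationError.
validate_schema(schema=schema, response='{"name": "Sugar Cookies"}')
```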
@@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list):
    return messages


def response_schema_prompt(model: str, response_schema: dict) -> str:
    """
    Decides if a user-defined custom prompt or default needs to be used

    Returns the prompt str that's passed to the model as a user message
    """
    custom_prompt_details: Optional[dict] = None
    response_schema_as_message = [
        {"role": "user", "content": "{}".format(response_schema)}
    ]
    if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:

        custom_prompt_details = litellm.custom_prompt_dict[
            f"{model}/response_schema_prompt"
        ]  # allow user to define custom response schema prompt by model
    elif "response_schema_prompt" in litellm.custom_prompt_dict:
        custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"]

    if custom_prompt_details is not None:
        return custom_prompt(
            role_dict=custom_prompt_details["roles"],
            initial_prompt_value=custom_prompt_details["initial_prompt_value"],
            final_prompt_value=custom_prompt_details["final_prompt_value"],
            messages=response_schema_as_message,
        )
    else:
        return default_response_schema_prompt(response_schema=response_schema)


def default_response_schema_prompt(response_schema: dict) -> str:
    """
    Used if provider/model doesn't support 'response_schema' param.

    This is the default prompt. Allow user to override this with a custom_prompt.
    """
    prompt_str = """Use this JSON schema:
    ```json
    {}
    ```""".format(
        response_schema
    )
    return prompt_str


# Custom prompt template
def custom_prompt(
    role_dict: dict,
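A small sketch of the default fallback prompt, assuming no `response_schema_prompt` entry exists in `litellm.custom_prompt_dict` (so `default_response_schema_prompt` is used):

```python
from litellm.llms.prompt_templates.factory import response_schema_prompt

prompt = response_schema_prompt(
    model="gemini-1.5-flash",  # illustrative model name
    response_schema={"type": "array", "items": {"type": "string"}},
)
print(prompt)
# "Use this JSON schema:" followed by the schema dict inside a fenced json block
```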
@@ -12,6 +12,7 @@ import requests  # type: ignore
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.prompt_templates.factory import (
    convert_to_anthropic_image_obj,
@@ -328,80 +329,86 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
    contents: List[ContentType] = []

    msg_i = 0
    try:
        while msg_i < len(messages):
            user_content: List[PartType] = []
            init_msg_i = msg_i
            ## MERGE CONSECUTIVE USER CONTENT ##
            while (
                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
            ):
                if isinstance(messages[msg_i]["content"], list):
                    _parts: List[PartType] = []
                    for element in messages[msg_i]["content"]:
                        if isinstance(element, dict):
                            if element["type"] == "text" and len(element["text"]) > 0:
                                _part = PartType(text=element["text"])
                                _parts.append(_part)
                            elif element["type"] == "image_url":
                                image_url = element["image_url"]["url"]
                                _part = _process_gemini_image(image_url=image_url)
                                _parts.append(_part)  # type: ignore
                    user_content.extend(_parts)
                elif (
                    isinstance(messages[msg_i]["content"], str)
                    and len(messages[msg_i]["content"]) > 0
                ):
                    _part = PartType(text=messages[msg_i]["content"])
                    user_content.append(_part)

                msg_i += 1

            if user_content:
                contents.append(ContentType(role="user", parts=user_content))
            assistant_content = []
            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                if isinstance(messages[msg_i]["content"], list):
                    _parts = []
                    for element in messages[msg_i]["content"]:
                        if isinstance(element, dict):
                            if element["type"] == "text":
                                _part = PartType(text=element["text"])
                                _parts.append(_part)
                            elif element["type"] == "image_url":
                                image_url = element["image_url"]["url"]
                                _part = _process_gemini_image(image_url=image_url)
                                _parts.append(_part)  # type: ignore
                    assistant_content.extend(_parts)
                elif messages[msg_i].get(
                    "tool_calls", []
                ):  # support assistant tool invoke conversion
                    assistant_content.extend(
                        convert_to_gemini_tool_call_invoke(
                            messages[msg_i]["tool_calls"]
                        )
                    )
                else:
                    assistant_text = (
                        messages[msg_i].get("content") or ""
                    )  # either string or none
                    if assistant_text:
                        assistant_content.append(PartType(text=assistant_text))

                msg_i += 1

            if assistant_content:
                contents.append(ContentType(role="model", parts=assistant_content))

            ## APPEND TOOL CALL MESSAGES ##
            if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
                _part = convert_to_gemini_tool_call_result(messages[msg_i])
                contents.append(ContentType(parts=[_part]))  # type: ignore
                msg_i += 1
            if msg_i == init_msg_i:  # prevent infinite loops
                raise Exception(
                    "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
                        messages[msg_i]
                    )
                )

        return contents
    except Exception as e:
        raise e


def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
@@ -12,7 +12,6 @@ from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import httpx  # type: ignore
import ijson
import requests  # type: ignore

import litellm
@@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import convert_url_to_base64
from litellm.llms.prompt_templates.factory import (
    convert_url_to_base64,
    response_schema_prompt,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.llms.openai import (
    ChatCompletionResponseMessage,
@@ -1011,35 +1013,53 @@ class VertexLLM(BaseLLM):
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)

        # Checks for 'response_schema' support - if passed in
        if "response_schema" in optional_params:
            supports_response_schema = litellm.supports_response_schema(
                model=model, custom_llm_provider="vertex_ai"
            )
            if supports_response_schema is False:
                user_response_schema_message = response_schema_prompt(
                    model=model, response_schema=optional_params.get("response_schema")  # type: ignore
                )
                messages.append(
                    {"role": "user", "content": user_response_schema_message}
                )
                optional_params.pop("response_schema")

        try:
            content = _gemini_convert_messages_with_history(messages=messages)
            tools: Optional[Tools] = optional_params.pop("tools", None)
            tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
            safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
                "safety_settings", None
            )  # type: ignore
            generation_config: Optional[GenerationConfig] = GenerationConfig(
                **optional_params
            )
            data = RequestBody(contents=content)
            if len(system_content_blocks) > 0:
                system_instructions = SystemInstructions(parts=system_content_blocks)
                data["system_instruction"] = system_instructions
            if tools is not None:
                data["tools"] = tools
            if tool_choice is not None:
                data["toolConfig"] = tool_choice
            if safety_settings is not None:
                data["safetySettings"] = safety_settings
            if generation_config is not None:
                data["generationConfig"] = generation_config

            headers = {
                "Content-Type": "application/json",
            }
            if auth_header is not None:
                headers["Authorization"] = f"Bearer {auth_header}"
            if extra_headers is not None:
                headers.update(extra_headers)
        except Exception as e:
            raise e

        ## LOGGING
        logging_obj.pre_call(
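In plain terms, the branch above only rewrites the request when the model reports no native `response_schema` support. A simplified, hypothetical restatement of that flow (not the actual helper in this file):

```python
import litellm
from litellm.llms.prompt_templates.factory import response_schema_prompt


def apply_response_schema_fallback(model: str, messages: list, optional_params: dict) -> None:
    """Hypothetical sketch mirroring the logic above; mutates its arguments in place."""
    if "response_schema" not in optional_params:
        return
    if litellm.supports_response_schema(model=model, custom_llm_provider="vertex_ai"):
        # Newer models: the schema stays in optional_params and is sent in generationConfig.
        return
    # Older models: turn the schema into an extra user message and drop the unsupported param.
    schema_prompt = response_schema_prompt(
        model=model, response_schema=optional_params["response_schema"]
    )
    messages.append({"role": "user", "content": schema_prompt})
    optional_params.pop("response_schema")
```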
@@ -1538,6 +1538,7 @@
        "supports_system_messages": true,
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-pro-preview-0215": {
@@ -1563,6 +1564,7 @@
        "supports_system_messages": true,
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-pro-preview-0409": {
@@ -1586,7 +1588,8 @@
        "litellm_provider": "vertex_ai-language-models",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-flash": {
@@ -880,10 +880,141 @@ Using this JSON schema:
    mock_call.assert_called_once()


@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
def vertex_httpx_mock_post_valid_response(*args, **kwargs):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = {
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
                        }
                    ],
                },
                "finishReason": "STOP",
                "safetyRatings": [
                    {
                        "category": "HARM_CATEGORY_HATE_SPEECH",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.09790669,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.11736965,
                    },
                    {
                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.1261379,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.08601588,
                    },
                    {
                        "category": "HARM_CATEGORY_HARASSMENT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.083441176,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.0355444,
                    },
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.071981624,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.08108212,
                    },
                ],
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 60,
            "candidatesTokenCount": 55,
            "totalTokenCount": 115,
        },
    }
    return mock_response


def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = {
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
                    ],
                },
                "finishReason": "STOP",
                "safetyRatings": [
                    {
                        "category": "HARM_CATEGORY_HATE_SPEECH",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.09790669,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.11736965,
                    },
                    {
                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.1261379,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.08601588,
                    },
                    {
                        "category": "HARM_CATEGORY_HARASSMENT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.083441176,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.0355444,
                    },
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": 0.071981624,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": 0.08108212,
                    },
                ],
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 60,
            "candidatesTokenCount": 55,
            "totalTokenCount": 115,
        },
    }
    return mock_response


@pytest.mark.parametrize(
    "model, supports_response_schema",
    [
        ("vertex_ai_beta/gemini-1.5-pro-001", True),
        ("vertex_ai_beta/gemini-1.5-flash", False),
    ],
)
@pytest.mark.parametrize(
    "invalid_response",
    [True, False],
)
@pytest.mark.parametrize(
    "enforce_validation",
    [True, False],
)
@pytest.mark.asyncio
async def test_gemini_pro_json_schema_httpx(provider):
async def test_gemini_pro_json_schema_args_sent_httpx(
    model, supports_response_schema, invalid_response, enforce_validation
):
    load_vertex_ai_credentials()
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    litellm.set_verbose = True
    messages = [{"role": "user", "content": "List 5 cookie recipes"}]
    from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -903,26 +1034,47 @@ async def test_gemini_pro_json_schema_httpx(provider):
    client = HTTPHandler()

    httpx_response = MagicMock()
    if invalid_response is True:
        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
    else:
        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
    with patch.object(client, "post", new=httpx_response) as mock_call:
        try:
            _ = completion(
                model=model,
                messages=messages,
                response_format={
                    "type": "json_object",
                    "response_schema": response_schema,
                    "enforce_validation": enforce_validation,
                },
                client=client,
            )
            if invalid_response is True and enforce_validation is True:
                pytest.fail("Expected this to fail")
        except litellm.JSONSchemaValidationError as e:
            if invalid_response is False:
                pytest.fail("Expected this to pass. Got={}".format(e))

        mock_call.assert_called_once()
        print(mock_call.call_args.kwargs)
        print(mock_call.call_args.kwargs["json"]["generationConfig"])

        if supports_response_schema:
            assert (
                "response_schema"
                in mock_call.call_args.kwargs["json"]["generationConfig"]
            )
        else:
            assert (
                "response_schema"
                not in mock_call.call_args.kwargs["json"]["generationConfig"]
            )
            assert (
                "Use this JSON schema:"
                in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
            )


@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
@@ -959,48 +1111,6 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
    assert "hello" in mock_call.call_args.kwargs["headers"]


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_httpx_custom_api_base_streaming_real_call(
    provider, sync_mode
):
    load_vertex_ai_credentials()
    import random

    litellm.set_verbose = True
    messages = [
        {
            "role": "user",
            "content": "Hey, how's it going?",
        }
    ]

    vertex_region = random.sample(["asia-southeast1", "us-central1"], k=1)[0]
    if sync_mode is True:
        response = completion(
            model="vertex_ai_beta/gemini-1.5-flash",
            messages=messages,
            api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash",
            stream=True,
            vertex_region=vertex_region,
        )

        for chunk in response:
            print(chunk)
    else:
        response = await litellm.acompletion(
            model="vertex_ai_beta/gemini-1.5-flash",
            messages=messages,
            api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash",
            stream=True,
            vertex_region=vertex_region,
        )

        async for chunk in response:
            print(chunk)


@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.parametrize("provider", ["vertex_ai"])
@@ -61,7 +61,6 @@ async def test_token_single_public_key():
    import jwt

    jwt_handler = JWTHandler()

    backend_keys = {
        "keys": [
            {
@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
import litellm
import litellm._service_logger  # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
@@ -580,7 +581,7 @@ def client(original_function):
        else:
            return False

    def post_call_processing(original_response, model):
    def post_call_processing(original_response, model, optional_params: Optional[dict]):
        try:
            if original_response is None:
                pass
@@ -595,11 +596,47 @@ def client(original_function):
                pass
            else:
                if isinstance(original_response, ModelResponse):
                    model_response: Optional[str] = original_response.choices[
                        0
                    ].message.content  # type: ignore
                    if model_response is not None:
                        ### POST-CALL RULES ###
                        rules_obj.post_call_rules(
                            input=model_response, model=model
                        )
                        ### JSON SCHEMA VALIDATION ###
                        if (
                            optional_params is not None
                            and "response_format" in optional_params
                            and isinstance(
                                optional_params["response_format"], dict
                            )
                            and "type" in optional_params["response_format"]
                            and optional_params["response_format"]["type"]
                            == "json_object"
                            and "response_schema"
                            in optional_params["response_format"]
                            and isinstance(
                                optional_params["response_format"][
                                    "response_schema"
                                ],
                                dict,
                            )
                            and "enforce_validation"
                            in optional_params["response_format"]
                            and optional_params["response_format"][
                                "enforce_validation"
                            ]
                            is True
                        ):
                            # schema given, json response expected, and validation enforced
                            litellm.litellm_core_utils.json_validation_rule.validate_schema(
                                schema=optional_params["response_format"][
                                    "response_schema"
                                ],
                                response=model_response,
                            )

        except Exception as e:
            raise e
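Validation is therefore opt-in: every condition above has to hold before `validate_schema` runs. A sketch of the `response_format` shape that triggers it, with key names taken from the checks above:

```python
# Only this combination triggers post-call JSON schema validation:
response_format = {
    "type": "json_object",      # a JSON response is expected
    "response_schema": {        # must be a dict
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
    },
    "enforce_validation": True,  # must be explicitly True
}
# If the schema is missing, not a dict, or enforce_validation is absent/False,
# validation is skipped and the response is returned unchanged.
```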
@@ -868,7 +905,11 @@ def client(original_function):
            return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model or None)
            post_call_processing(
                original_response=result,
                model=model or None,
                optional_params=kwargs,
            )

            # [OPTIONAL] ADD TO CACHE
            if (
@@ -1317,7 +1358,9 @@ def client(original_function):
            ).total_seconds() * 1000  # return response latency in ms like openai

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)
            post_call_processing(
                original_response=result, model=model, optional_params=kwargs
            )

            # [OPTIONAL] ADD TO CACHE
            if (
@@ -1880,8 +1923,7 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
    Returns:
        bool: True if the model supports response_schema, False otherwise.

    Raises:
        Exception: If the given model is not found in model_prices_and_context_window.json.
    Does not raise error. Defaults to 'False'. Outputs logging.error.
    """
    try:
        ## GET LLM PROVIDER ##
@@ -1901,9 +1943,10 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
            return True
        return False
    except Exception:
        raise Exception(
        verbose_logger.error(
            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
        )
        return False


def supports_function_calling(model: str) -> bool:
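With this change, probing capability is safe for arbitrary model names: unknown models log an error and report `False` instead of raising. A usage sketch, assuming the local model map contains the entries updated in this commit:

```python
import litellm

# Listed with "supports_response_schema": true in model_prices_and_context_window.json
print(litellm.supports_response_schema(model="gemini-1.5-pro-001", custom_llm_provider="vertex_ai"))
# expected: True

# Unknown model: previously raised an Exception, now logs an error and returns False
print(litellm.supports_response_schema(model="some-unknown-model", custom_llm_provider="vertex_ai"))
# expected: False
```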
@@ -1538,6 +1538,7 @@
        "supports_system_messages": true,
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-pro-preview-0215": {
@@ -1563,6 +1564,7 @@
        "supports_system_messages": true,
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-pro-preview-0409": {
@@ -1586,7 +1588,8 @@
        "litellm_provider": "vertex_ai-language-models",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_tool_choice": true,
        "supports_response_schema": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
    },
    "gemini-1.5-flash": {
@@ -27,7 +27,7 @@ jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
ijson = "*"
jsonschema = "^4.22.0"

uvicorn = {version = "^0.22.0", optional = true}
gunicorn = {version = "^22.0.0", optional = true}
@@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls
tenacity==8.2.3  # for retrying requests, when litellm.num_retries set
pydantic==2.7.1 # proxy + openai req.
ijson==3.2.3 # for google ai studio streaming
jsonschema==4.22.0 # validating json schema
####