mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Merge pull request #4160 from BerriAI/litellm_vertex_completion_httpx
feat(vertex_httpx.py): Support Vertex AI system messages, JSON Schema, etc.
This commit is contained in:
commit
c6a4f44d76
14 changed files with 1333 additions and 61 deletions
|
@ -8,6 +8,152 @@ import TabItem from '@theme/TabItem';
|
||||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
## 🆕 `vertex_ai_beta/` route
|
||||||
|
|
||||||
|
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai_beta/gemini-pro",
|
||||||
|
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
vertex_credentials=vertex_credentials_json
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### **System Message**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai_beta/gemini-pro",
|
||||||
|
messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}],
|
||||||
|
vertex_credentials=vertex_credentials_json
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### **Function Calling**
|
||||||
|
|
||||||
|
Force Gemini to make tool calls with `tool_choice="required"`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "vertex_ai_beta/gemini-1.5-pro-preview-0514"),
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"tool_choice": "required",
|
||||||
|
"vertex_credentials": vertex_credentials_json
|
||||||
|
}
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
print(completion(**data))
|
||||||
|
```
|
||||||
|
|
||||||
|
### **JSON Schema**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": """
|
||||||
|
List 5 popular cookie recipes.
|
||||||
|
|
||||||
|
Using this JSON schema:
|
||||||
|
|
||||||
|
Recipe = {"recipe_name": str}
|
||||||
|
|
||||||
|
Return a `list[Recipe]`
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
completion(model="vertex_ai_beta/gemini-1.5-flash-preview-0514", messages=messages, response_format={ "type": "json_object" })
|
||||||
|
```
|
||||||
|
|
||||||
## Pre-requisites
|
## Pre-requisites
|
||||||
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
||||||
* Authentication:
|
* Authentication:
|
||||||
|
@ -140,7 +286,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
|
||||||
|
|
||||||
```python
|
```python
|
||||||
response = completion(
|
response = completion(
|
||||||
model="gemini/gemini-pro",
|
model="vertex_ai/gemini-pro",
|
||||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||||
safety_settings=[
|
safety_settings=[
|
||||||
{
|
{
|
||||||
|
@ -680,6 +826,3 @@ s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
|
||||||
response.choices[0], litellm.utils.Choices
|
response.choices[0], litellm.utils.Choices
|
||||||
):
|
):
|
||||||
for word in self.banned_keywords_list:
|
for word in self.banned_keywords_list:
|
||||||
self.test_violation(test_str=response.choices[0].message.content)
|
self.test_violation(test_str=response.choices[0].message.content or "")
|
||||||
|
|
||||||
async def async_post_call_streaming_hook(
|
async def async_post_call_streaming_hook(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -605,6 +605,7 @@ provider_list: List = [
|
||||||
"together_ai",
|
"together_ai",
|
||||||
"openrouter",
|
"openrouter",
|
||||||
"vertex_ai",
|
"vertex_ai",
|
||||||
|
"vertex_ai_beta",
|
||||||
"palm",
|
"palm",
|
||||||
"gemini",
|
"gemini",
|
||||||
"ai21",
|
"ai21",
|
||||||
|
@ -765,6 +766,7 @@ from .llms.gemini import GeminiConfig
|
||||||
from .llms.nlp_cloud import NLPCloudConfig
|
from .llms.nlp_cloud import NLPCloudConfig
|
||||||
from .llms.aleph_alpha import AlephAlphaConfig
|
from .llms.aleph_alpha import AlephAlphaConfig
|
||||||
from .llms.petals import PetalsConfig
|
from .llms.petals import PetalsConfig
|
||||||
|
from .llms.vertex_httpx import VertexGeminiConfig
|
||||||
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
||||||
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
||||||
from .llms.sagemaker import SagemakerConfig
|
from .llms.sagemaker import SagemakerConfig
|
||||||
|
|
|
@ -617,7 +617,7 @@ def completion(
|
||||||
llm_model = None
|
llm_model = None
|
||||||
|
|
||||||
# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
|
# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
|
||||||
if acompletion == True:
|
if acompletion is True:
|
||||||
data = {
|
data = {
|
||||||
"llm_model": llm_model,
|
"llm_model": llm_model,
|
||||||
"mode": mode,
|
"mode": mode,
|
||||||
|
@ -649,7 +649,7 @@ def completion(
|
||||||
tools = optional_params.pop("tools", None)
|
tools = optional_params.pop("tools", None)
|
||||||
content = _gemini_convert_messages_with_history(messages=messages)
|
content = _gemini_convert_messages_with_history(messages=messages)
|
||||||
stream = optional_params.pop("stream", False)
|
stream = optional_params.pop("stream", False)
|
||||||
if stream == True:
|
if stream is True:
|
||||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
# What is this?
|
||||||
|
## httpx client for vertex ai calls
|
||||||
|
## Initial implementation - covers gemini + image gen calls
|
||||||
|
from functools import partial
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -9,6 +13,284 @@ import litellm, uuid
|
||||||
import httpx, inspect # type: ignore
|
import httpx, inspect # type: ignore
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
|
from litellm.types.llms.vertex_ai import (
|
||||||
|
ContentType,
|
||||||
|
SystemInstructions,
|
||||||
|
PartType,
|
||||||
|
RequestBody,
|
||||||
|
GenerateContentResponseBody,
|
||||||
|
FunctionCallingConfig,
|
||||||
|
FunctionDeclaration,
|
||||||
|
Tools,
|
||||||
|
ToolConfig,
|
||||||
|
GenerationConfig,
|
||||||
|
)
|
||||||
|
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
|
||||||
|
from litellm.types.utils import GenericStreamingChunk
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
ChatCompletionUsageBlock,
|
||||||
|
ChatCompletionToolCallChunk,
|
||||||
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
ChatCompletionResponseMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VertexGeminiConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
|
||||||
|
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
|
|
||||||
|
The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
|
||||||
|
|
||||||
|
- `temperature` (float): This controls the degree of randomness in token selection.
|
||||||
|
|
||||||
|
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
|
||||||
|
|
||||||
|
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
|
||||||
|
|
||||||
|
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
|
||||||
|
|
||||||
|
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
|
||||||
|
|
||||||
|
- `candidate_count` (int): Number of generated responses to return.
|
||||||
|
|
||||||
|
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
|
||||||
|
|
||||||
|
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
|
||||||
|
|
||||||
|
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
|
||||||
|
|
||||||
|
Note: Please make sure to modify the default parameters as required for your use case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
max_output_tokens: Optional[int] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
response_mime_type: Optional[str] = None
|
||||||
|
candidate_count: Optional[int] = None
|
||||||
|
stop_sequences: Optional[list] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
response_mime_type: Optional[str] = None,
|
||||||
|
candidate_count: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[list] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"response_format",
|
||||||
|
"n",
|
||||||
|
"stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_tool_choice_values(
|
||||||
|
self, model: str, tool_choice: Union[str, dict]
|
||||||
|
) -> Optional[ToolConfig]:
|
||||||
|
if tool_choice == "none":
|
||||||
|
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
|
||||||
|
elif tool_choice == "required":
|
||||||
|
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
|
||||||
|
elif tool_choice == "auto":
|
||||||
|
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
|
||||||
|
elif isinstance(tool_choice, dict):
|
||||||
|
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
|
name = tool_choice.get("function", {}).get("name", "")
|
||||||
|
return ToolConfig(
|
||||||
|
functionCallingConfig=FunctionCallingConfig(
|
||||||
|
mode="ANY", allowed_function_names=[name]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
tool_choice
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
):
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if (
|
||||||
|
param == "stream" and value is True
|
||||||
|
): # sending stream = False, can cause it to get passed unchecked and raise issues
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["candidate_count"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
optional_params["stop_sequences"] = [value]
|
||||||
|
elif isinstance(value, list):
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_output_tokens"] = value
|
||||||
|
if param == "response_format" and value["type"] == "json_object": # type: ignore
|
||||||
|
optional_params["response_mime_type"] = "application/json"
|
||||||
|
if param == "frequency_penalty":
|
||||||
|
optional_params["frequency_penalty"] = value
|
||||||
|
if param == "presence_penalty":
|
||||||
|
optional_params["presence_penalty"] = value
|
||||||
|
if param == "tools" and isinstance(value, list):
|
||||||
|
gtool_func_declarations = []
|
||||||
|
for tool in value:
|
||||||
|
gtool_func_declaration = FunctionDeclaration(
|
||||||
|
name=tool["function"]["name"],
|
||||||
|
description=tool["function"].get("description", ""),
|
||||||
|
parameters=tool["function"].get("parameters", {}),
|
||||||
|
)
|
||||||
|
gtool_func_declarations.append(gtool_func_declaration)
|
||||||
|
optional_params["tools"] = [
|
||||||
|
Tools(function_declarations=gtool_func_declarations)
|
||||||
|
]
|
||||||
|
if param == "tool_choice" and (
|
||||||
|
isinstance(value, str) or isinstance(value, dict)
|
||||||
|
):
|
||||||
|
_tool_choice_value = self.map_tool_choice_values(
|
||||||
|
model=model, tool_choice=value # type: ignore
|
||||||
|
)
|
||||||
|
if _tool_choice_value is not None:
|
||||||
|
optional_params["tool_choice"] = _tool_choice_value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def get_mapped_special_auth_params(self) -> dict:
|
||||||
|
"""
|
||||||
|
Common auth params across bedrock/vertex_ai/azure/watsonx
|
||||||
|
"""
|
||||||
|
return {"project": "vertex_project", "region_name": "vertex_location"}
|
||||||
|
|
||||||
|
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||||
|
mapped_params = self.get_mapped_special_auth_params()
|
||||||
|
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param in mapped_params:
|
||||||
|
optional_params[mapped_params[param]] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"europe-central2",
|
||||||
|
"europe-north1",
|
||||||
|
"europe-southwest1",
|
||||||
|
"europe-west1",
|
||||||
|
"europe-west2",
|
||||||
|
"europe-west3",
|
||||||
|
"europe-west4",
|
||||||
|
"europe-west6",
|
||||||
|
"europe-west8",
|
||||||
|
"europe-west9",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def make_call(
|
||||||
|
client: Optional[AsyncHTTPHandler],
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: str,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
logging_obj,
|
||||||
|
):
|
||||||
|
if client is None:
|
||||||
|
client = AsyncHTTPHandler() # Create a new client if none provided
|
||||||
|
|
||||||
|
response = await client.post(api_base, headers=headers, data=data, stream=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise VertexAIError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
|
completion_stream = ModelResponseIterator(
|
||||||
|
streaming_response=response.aiter_bytes(chunk_size=2056)
|
||||||
|
)
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response="first stream response received",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
|
def make_sync_call(
|
||||||
|
client: Optional[HTTPHandler],
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: str,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
logging_obj,
|
||||||
|
):
|
||||||
|
if client is None:
|
||||||
|
client = HTTPHandler() # Create a new client if none provided
|
||||||
|
|
||||||
|
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise VertexAIError(status_code=response.status_code, message=response.read())
|
||||||
|
|
||||||
|
completion_stream = ModelResponseIterator(
|
||||||
|
streaming_response=response.iter_bytes(chunk_size=2056)
|
||||||
|
)
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response="first stream response received",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
class VertexAIError(Exception):
|
class VertexAIError(Exception):
|
||||||
|
@ -33,16 +315,125 @@ class VertexLLM(BaseLLM):
|
||||||
self.project_id: Optional[str] = None
|
self.project_id: Optional[str] = None
|
||||||
self.async_handler: Optional[AsyncHTTPHandler] = None
|
self.async_handler: Optional[AsyncHTTPHandler] = None
|
||||||
|
|
||||||
def load_auth(self) -> Tuple[Any, str]:
|
def _process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: litellm.utils.Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise VertexAIError(
|
||||||
|
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||||
|
response.text, str(e)
|
||||||
|
),
|
||||||
|
status_code=422,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response.choices = [] # type: ignore
|
||||||
|
|
||||||
|
## GET MODEL ##
|
||||||
|
model_response.model = model
|
||||||
|
## GET TEXT ##
|
||||||
|
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||||
|
content_str = ""
|
||||||
|
tools: List[ChatCompletionToolCallChunk] = []
|
||||||
|
for idx, candidate in enumerate(completion_response["candidates"]):
|
||||||
|
if "content" not in candidate:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "text" in candidate["content"]["parts"][0]:
|
||||||
|
content_str = candidate["content"]["parts"][0]["text"]
|
||||||
|
|
||||||
|
if "functionCall" in candidate["content"]["parts"][0]:
|
||||||
|
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=candidate["content"]["parts"][0]["functionCall"]["name"],
|
||||||
|
arguments=json.dumps(
|
||||||
|
candidate["content"]["parts"][0]["functionCall"]["args"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||||
|
id=f"call_{str(uuid.uuid4())}",
|
||||||
|
type="function",
|
||||||
|
function=_function_chunk,
|
||||||
|
)
|
||||||
|
tools.append(_tool_response_chunk)
|
||||||
|
|
||||||
|
chat_completion_message["content"] = content_str
|
||||||
|
chat_completion_message["tool_calls"] = tools
|
||||||
|
|
||||||
|
choice = litellm.Choices(
|
||||||
|
finish_reason=candidate.get("finishReason", "stop"),
|
||||||
|
index=candidate.get("index", idx),
|
||||||
|
message=chat_completion_message, # type: ignore
|
||||||
|
logprobs=None,
|
||||||
|
enhancements=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response.choices.append(choice)
|
||||||
|
|
||||||
|
## GET USAGE ##
|
||||||
|
usage = litellm.Usage(
|
||||||
|
prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
|
||||||
|
completion_tokens=completion_response["usageMetadata"][
|
||||||
|
"candidatesTokenCount"
|
||||||
|
],
|
||||||
|
total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
|
||||||
|
return vertex_region or "us-central1"
|
||||||
|
|
||||||
|
def load_auth(
|
||||||
|
self, credentials: Optional[str], project_id: Optional[str]
|
||||||
|
) -> Tuple[Any, str]:
|
||||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||||
from google.auth.credentials import Credentials # type: ignore[import-untyped]
|
from google.auth.credentials import Credentials # type: ignore[import-untyped]
|
||||||
import google.auth as google_auth
|
import google.auth as google_auth
|
||||||
|
|
||||||
credentials, project_id = google_auth.default(
|
if credentials is not None and isinstance(credentials, str):
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
import google.oauth2.service_account
|
||||||
)
|
|
||||||
|
|
||||||
credentials.refresh(Request())
|
json_obj = json.loads(credentials)
|
||||||
|
|
||||||
|
creds = google.oauth2.service_account.Credentials.from_service_account_info(
|
||||||
|
json_obj,
|
||||||
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if project_id is None:
|
||||||
|
project_id = creds.project_id
|
||||||
|
else:
|
||||||
|
creds, project_id = google_auth.default(
|
||||||
|
quota_project_id=project_id,
|
||||||
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||||
|
)
|
||||||
|
|
||||||
|
creds.refresh(Request())
|
||||||
|
|
||||||
if not project_id:
|
if not project_id:
|
||||||
raise ValueError("Could not resolve project_id")
|
raise ValueError("Could not resolve project_id")
|
||||||
|
@ -52,38 +443,272 @@ class VertexLLM(BaseLLM):
|
||||||
f"Expected project_id to be a str but got {type(project_id)}"
|
f"Expected project_id to be a str but got {type(project_id)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return credentials, project_id
|
return creds, project_id
|
||||||
|
|
||||||
def refresh_auth(self, credentials: Any) -> None:
|
def refresh_auth(self, credentials: Any) -> None:
|
||||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||||
|
|
||||||
credentials.refresh(Request())
|
credentials.refresh(Request())
|
||||||
|
|
||||||
def _prepare_request(self, request: httpx.Request) -> None:
|
def _ensure_access_token(
|
||||||
access_token = self._ensure_access_token()
|
self, credentials: Optional[str], project_id: Optional[str]
|
||||||
|
) -> Tuple[str, str]:
|
||||||
if request.headers.get("Authorization"):
|
"""
|
||||||
# already authenticated, nothing for us to do
|
Returns auth token and project id
|
||||||
return
|
"""
|
||||||
|
if self.access_token is not None and self.project_id is not None:
|
||||||
request.headers["Authorization"] = f"Bearer {access_token}"
|
return self.access_token, self.project_id
|
||||||
|
|
||||||
def _ensure_access_token(self) -> str:
|
|
||||||
if self.access_token is not None:
|
|
||||||
return self.access_token
|
|
||||||
|
|
||||||
if not self._credentials:
|
if not self._credentials:
|
||||||
self._credentials, project_id = self.load_auth()
|
self._credentials, project_id = self.load_auth(
|
||||||
|
credentials=credentials, project_id=project_id
|
||||||
|
)
|
||||||
if not self.project_id:
|
if not self.project_id:
|
||||||
self.project_id = project_id
|
self.project_id = project_id
|
||||||
else:
|
else:
|
||||||
self.refresh_auth(self._credentials)
|
self.refresh_auth(self._credentials)
|
||||||
|
|
||||||
if not self._credentials.token:
|
if not self.project_id:
|
||||||
|
self.project_id = self._credentials.project_id
|
||||||
|
|
||||||
|
if not self.project_id:
|
||||||
|
raise ValueError("Could not resolve project_id")
|
||||||
|
|
||||||
|
if not self._credentials or not self._credentials.token:
|
||||||
raise RuntimeError("Could not resolve API token from the environment")
|
raise RuntimeError("Could not resolve API token from the environment")
|
||||||
|
|
||||||
assert isinstance(self._credentials.token, str)
|
return self._credentials.token, self.project_id
|
||||||
return self._credentials.token
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=None,
|
||||||
|
make_call=partial(
|
||||||
|
make_call,
|
||||||
|
client=client,
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(api_base, headers=headers, json=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise VertexAIError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self._process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
vertex_project: Optional[str],
|
||||||
|
vertex_location: Optional[str],
|
||||||
|
vertex_credentials: Optional[str],
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
auth_header, vertex_project = self._ensure_access_token(
|
||||||
|
credentials=vertex_credentials, project_id=vertex_project
|
||||||
|
)
|
||||||
|
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
|
||||||
|
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
|
||||||
|
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
# Separate system prompt from rest of message
|
||||||
|
system_prompt_indices = []
|
||||||
|
system_content_blocks: List[PartType] = []
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "system":
|
||||||
|
_system_content_block = PartType(text=message["content"])
|
||||||
|
system_content_blocks.append(_system_content_block)
|
||||||
|
system_prompt_indices.append(idx)
|
||||||
|
if len(system_prompt_indices) > 0:
|
||||||
|
for idx in reversed(system_prompt_indices):
|
||||||
|
messages.pop(idx)
|
||||||
|
system_instructions = SystemInstructions(parts=system_content_blocks)
|
||||||
|
content = _gemini_convert_messages_with_history(messages=messages)
|
||||||
|
tools: Optional[Tools] = optional_params.pop("tools", None)
|
||||||
|
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
||||||
|
generation_config: Optional[GenerationConfig] = GenerationConfig(
|
||||||
|
**optional_params
|
||||||
|
)
|
||||||
|
data = RequestBody(system_instruction=system_instructions, contents=content)
|
||||||
|
if tools is not None:
|
||||||
|
data["tools"] = tools
|
||||||
|
if tool_choice is not None:
|
||||||
|
data["toolConfig"] = tool_choice
|
||||||
|
if generation_config is not None:
|
||||||
|
data["generationConfig"] = generation_config
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Authorization": f"Bearer {auth_header}",
|
||||||
|
}
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": url,
|
||||||
|
"headers": headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data, # type: ignore
|
||||||
|
api_base=url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
## SYNC STREAMING CALL ##
|
||||||
|
if stream is not None and stream is True:
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=None,
|
||||||
|
make_call=partial(
|
||||||
|
make_sync_call,
|
||||||
|
client=None,
|
||||||
|
api_base=url,
|
||||||
|
headers=headers, # type: ignore
|
||||||
|
data=json.dumps(data),
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return streaming_response
|
||||||
|
## COMPLETION CALL ##
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client
|
||||||
|
try:
|
||||||
|
response = client.post(url=url, headers=headers, json=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise VertexAIError(status_code=error_code, message=response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self._process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key="",
|
||||||
|
data=data, # type: ignore
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
def image_generation(
|
def image_generation(
|
||||||
self,
|
self,
|
||||||
|
@ -163,7 +788,7 @@ class VertexLLM(BaseLLM):
|
||||||
} \
|
} \
|
||||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||||
"""
|
"""
|
||||||
auth_header = self._ensure_access_token()
|
auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
|
||||||
optional_params = optional_params or {
|
optional_params = optional_params or {
|
||||||
"sampleCount": 1
|
"sampleCount": 1
|
||||||
} # default optional params
|
} # default optional params
|
||||||
|
@ -222,3 +847,84 @@ class VertexLLM(BaseLLM):
|
||||||
model_response.data = _response_data
|
model_response.data = _response_data
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
class ModelResponseIterator:
|
||||||
|
def __init__(self, streaming_response):
|
||||||
|
self.streaming_response = streaming_response
|
||||||
|
self.response_iterator = iter(self.streaming_response)
|
||||||
|
|
||||||
|
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||||
|
try:
|
||||||
|
processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore
|
||||||
|
text = ""
|
||||||
|
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||||
|
is_finished = False
|
||||||
|
finish_reason = ""
|
||||||
|
usage: Optional[ChatCompletionUsageBlock] = None
|
||||||
|
|
||||||
|
gemini_chunk = processed_chunk["candidates"][0]
|
||||||
|
|
||||||
|
if (
|
||||||
|
"content" in gemini_chunk
|
||||||
|
and "text" in gemini_chunk["content"]["parts"][0]
|
||||||
|
):
|
||||||
|
text = gemini_chunk["content"]["parts"][0]["text"]
|
||||||
|
|
||||||
|
if "finishReason" in gemini_chunk:
|
||||||
|
finish_reason = map_finish_reason(
|
||||||
|
finish_reason=gemini_chunk["finishReason"]
|
||||||
|
)
|
||||||
|
is_finished = True
|
||||||
|
|
||||||
|
if "usageMetadata" in processed_chunk:
|
||||||
|
usage = ChatCompletionUsageBlock(
|
||||||
|
prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
|
||||||
|
completion_tokens=processed_chunk["usageMetadata"][
|
||||||
|
"candidatesTokenCount"
|
||||||
|
],
|
||||||
|
total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
|
||||||
|
)
|
||||||
|
|
||||||
|
returned_chunk = GenericStreamingChunk(
|
||||||
|
text=text,
|
||||||
|
tool_use=tool_use,
|
||||||
|
is_finished=is_finished,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
return returned_chunk
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||||
|
|
||||||
|
# Sync iterator
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
try:
|
||||||
|
chunk = next(self.response_iterator)
|
||||||
|
chunk = chunk.decode()
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e}")
|
||||||
|
|
||||||
|
# Async iterator
|
||||||
|
def __aiter__(self):
|
||||||
|
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
try:
|
||||||
|
chunk = await self.async_response_iterator.__anext__()
|
||||||
|
chunk = chunk.decode()
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e}")
|
||||||
|
|
|
@ -329,6 +329,7 @@ async def acompletion(
|
||||||
or custom_llm_provider == "ollama_chat"
|
or custom_llm_provider == "ollama_chat"
|
||||||
or custom_llm_provider == "replicate"
|
or custom_llm_provider == "replicate"
|
||||||
or custom_llm_provider == "vertex_ai"
|
or custom_llm_provider == "vertex_ai"
|
||||||
|
or custom_llm_provider == "vertex_ai_beta"
|
||||||
or custom_llm_provider == "gemini"
|
or custom_llm_provider == "gemini"
|
||||||
or custom_llm_provider == "sagemaker"
|
or custom_llm_provider == "sagemaker"
|
||||||
or custom_llm_provider == "anthropic"
|
or custom_llm_provider == "anthropic"
|
||||||
|
@ -1876,6 +1877,42 @@ def completion(
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
response = model_response
|
response = model_response
|
||||||
|
elif custom_llm_provider == "vertex_ai_beta":
|
||||||
|
vertex_ai_project = (
|
||||||
|
optional_params.pop("vertex_project", None)
|
||||||
|
or optional_params.pop("vertex_ai_project", None)
|
||||||
|
or litellm.vertex_project
|
||||||
|
or get_secret("VERTEXAI_PROJECT")
|
||||||
|
)
|
||||||
|
vertex_ai_location = (
|
||||||
|
optional_params.pop("vertex_location", None)
|
||||||
|
or optional_params.pop("vertex_ai_location", None)
|
||||||
|
or litellm.vertex_location
|
||||||
|
or get_secret("VERTEXAI_LOCATION")
|
||||||
|
)
|
||||||
|
vertex_credentials = (
|
||||||
|
optional_params.pop("vertex_credentials", None)
|
||||||
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
|
)
|
||||||
|
new_params = deepcopy(optional_params)
|
||||||
|
response = vertex_chat_completion.completion( # type: ignore
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=new_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
encoding=encoding,
|
||||||
|
vertex_location=vertex_ai_location,
|
||||||
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
logging_obj=logging,
|
||||||
|
acompletion=acompletion,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
elif custom_llm_provider == "vertex_ai":
|
elif custom_llm_provider == "vertex_ai":
|
||||||
vertex_ai_project = (
|
vertex_ai_project = (
|
||||||
optional_params.pop("vertex_project", None)
|
optional_params.pop("vertex_project", None)
|
||||||
|
@ -1894,6 +1931,7 @@ def completion(
|
||||||
or optional_params.pop("vertex_ai_credentials", None)
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
or get_secret("VERTEXAI_CREDENTIALS")
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
)
|
)
|
||||||
|
|
||||||
new_params = deepcopy(optional_params)
|
new_params = deepcopy(optional_params)
|
||||||
if "claude-3" in model:
|
if "claude-3" in model:
|
||||||
model_response = vertex_ai_anthropic.completion(
|
model_response = vertex_ai_anthropic.completion(
|
||||||
|
|
|
@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety(
|
||||||
response.choices[0], litellm.utils.Choices
|
response.choices[0], litellm.utils.Choices
|
||||||
):
|
):
|
||||||
await self.test_violation(
|
await self.test_violation(
|
||||||
content=response.choices[0].message.content, source="output"
|
content=response.choices[0].message.content or "", source="output"
|
||||||
)
|
)
|
||||||
|
|
||||||
# async def async_post_call_streaming_hook(
|
# async def async_post_call_streaming_hook(
|
||||||
|
|
|
@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response():
|
||||||
# asyncio.run(test_async_vertexai_streaming_response())
|
# asyncio.run(test_async_vertexai_streaming_response())
|
||||||
|
|
||||||
|
|
||||||
def test_gemini_pro_vision():
|
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_vision(provider, sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
litellm.num_retries = 3
|
litellm.num_retries = 3
|
||||||
resp = litellm.completion(
|
if sync_mode:
|
||||||
model="vertex_ai/gemini-1.5-flash-preview-0514",
|
resp = litellm.completion(
|
||||||
messages=[
|
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
{"type": "text", "text": "Whats in this image?"},
|
"content": [
|
||||||
{
|
{"type": "text", "text": "Whats in this image?"},
|
||||||
"type": "image_url",
|
{
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
"image_url": {
|
||||||
|
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
],
|
||||||
],
|
}
|
||||||
}
|
],
|
||||||
],
|
)
|
||||||
)
|
else:
|
||||||
|
resp = await litellm.acompletion(
|
||||||
|
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Whats in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
print(resp)
|
print(resp)
|
||||||
|
|
||||||
prompt_tokens = resp.usage.prompt_tokens
|
prompt_tokens = resp.usage.prompt_tokens
|
||||||
|
@ -532,6 +554,8 @@ def test_gemini_pro_vision():
|
||||||
# DO Not DELETE this ASSERT
|
# DO Not DELETE this ASSERT
|
||||||
# Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
|
# Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
|
||||||
assert prompt_tokens == 263 # the gemini api returns 263 to us
|
assert prompt_tokens == 263 # the gemini api returns 263 to us
|
||||||
|
|
||||||
|
# assert False
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -591,9 +615,111 @@ def test_gemini_pro_vision_base64():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
|
||||||
|
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True]) # "vertex_ai",
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gemini_pro_function_calling(sync_mode):
|
async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
|
||||||
|
try:
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "{}/gemini-1.5-pro".format(provider),
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"tool_choice": "required",
|
||||||
|
}
|
||||||
|
if sync_mode:
|
||||||
|
response = litellm.completion(**data)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(**data)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
assert response.choices[0].message.tool_calls[0].function.arguments is not None
|
||||||
|
assert isinstance(
|
||||||
|
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
if "429 Quota exceeded" in str(e):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
|
||||||
|
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_json_schema_httpx(provider):
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": """
|
||||||
|
List 5 popular cookie recipes.
|
||||||
|
|
||||||
|
Using this JSON schema:
|
||||||
|
|
||||||
|
Recipe = {"recipe_name": str}
|
||||||
|
|
||||||
|
Return a `list[Recipe]`
|
||||||
|
""",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai_beta/gemini-1.5-flash-preview-0514",
|
||||||
|
messages=messages,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
response_json = json.loads(response.choices[0].message.content)
|
||||||
|
|
||||||
|
assert isinstance(response_json, dict) or isinstance(response_json, list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True])
|
||||||
|
@pytest.mark.parametrize("provider", ["vertex_ai"])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_function_calling(provider, sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
@ -655,7 +781,7 @@ async def test_gemini_pro_function_calling(sync_mode):
|
||||||
]
|
]
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": "vertex_ai/gemini-1.5-pro-preview-0514",
|
"model": "{}/gemini-1.5-pro-preview-0514".format(provider),
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1035,7 +1035,8 @@ def test_completion_claude_stream_bad_key():
|
||||||
# test_completion_replicate_stream()
|
# test_completion_replicate_stream()
|
||||||
|
|
||||||
|
|
||||||
def test_vertex_ai_stream():
|
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
|
||||||
|
def test_vertex_ai_stream(provider):
|
||||||
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
|
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
|
||||||
|
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
@ -1048,7 +1049,7 @@ def test_vertex_ai_stream():
|
||||||
try:
|
try:
|
||||||
print("making request", model)
|
print("making request", model)
|
||||||
response = completion(
|
response = completion(
|
||||||
model=model,
|
model="{}/{}".format(provider, model),
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "user", "content": "write 10 line code code for saying hi"}
|
{"role": "user", "content": "write 10 line code code for saying hi"}
|
||||||
],
|
],
|
||||||
|
|
|
@ -323,3 +323,9 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
|
||||||
content: Optional[str]
|
content: Optional[str]
|
||||||
tool_calls: List[ChatCompletionToolCallChunk]
|
tool_calls: List[ChatCompletionToolCallChunk]
|
||||||
role: Literal["assistant"]
|
role: Literal["assistant"]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionUsageBlock(TypedDict):
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing_extensions import (
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
Required,
|
Required,
|
||||||
)
|
)
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class Field(TypedDict):
|
class Field(TypedDict):
|
||||||
|
@ -48,6 +49,190 @@ class PartType(TypedDict, total=False):
|
||||||
function_response: FunctionResponse
|
function_response: FunctionResponse
|
||||||
|
|
||||||
|
|
||||||
|
class HttpxFunctionCall(TypedDict):
|
||||||
|
name: str
|
||||||
|
args: dict
|
||||||
|
|
||||||
|
|
||||||
|
class HttpxPartType(TypedDict, total=False):
|
||||||
|
text: str
|
||||||
|
inline_data: BlobType
|
||||||
|
file_data: FileDataType
|
||||||
|
functionCall: HttpxFunctionCall
|
||||||
|
function_response: FunctionResponse
|
||||||
|
|
||||||
|
|
||||||
|
class HttpxContentType(TypedDict, total=False):
|
||||||
|
role: Literal["user", "model"]
|
||||||
|
parts: Required[List[HttpxPartType]]
|
||||||
|
|
||||||
|
|
||||||
class ContentType(TypedDict, total=False):
|
class ContentType(TypedDict, total=False):
|
||||||
role: Literal["user", "model"]
|
role: Literal["user", "model"]
|
||||||
parts: Required[List[PartType]]
|
parts: Required[List[PartType]]
|
||||||
|
|
||||||
|
|
||||||
|
class SystemInstructions(TypedDict):
|
||||||
|
parts: Required[List[PartType]]
|
||||||
|
|
||||||
|
|
||||||
|
class Schema(TypedDict, total=False):
|
||||||
|
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
|
||||||
|
description: str
|
||||||
|
enum: List[str]
|
||||||
|
items: List["Schema"]
|
||||||
|
properties: "Schema"
|
||||||
|
required: List[str]
|
||||||
|
nullable: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDeclaration(TypedDict, total=False):
|
||||||
|
name: Required[str]
|
||||||
|
description: str
|
||||||
|
parameters: Schema
|
||||||
|
response: Schema
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCallingConfig(TypedDict, total=False):
|
||||||
|
mode: Literal["ANY", "AUTO", "NONE"]
|
||||||
|
allowed_function_names: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
HarmCategory = Literal[
|
||||||
|
"HARM_CATEGORY_UNSPECIFIED",
|
||||||
|
"HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"HARM_CATEGORY_HARASSMENT",
|
||||||
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
]
|
||||||
|
HarmBlockThreshold = Literal[
|
||||||
|
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_NONE",
|
||||||
|
]
|
||||||
|
HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"]
|
||||||
|
|
||||||
|
HarmProbability = Literal[
|
||||||
|
"HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"
|
||||||
|
]
|
||||||
|
|
||||||
|
HarmSeverity = Literal[
|
||||||
|
"HARM_SEVERITY_UNSPECIFIED",
|
||||||
|
"HARM_SEVERITY_NEGLIGIBLE",
|
||||||
|
"HARM_SEVERITY_LOW",
|
||||||
|
"HARM_SEVERITY_MEDIUM",
|
||||||
|
"HARM_SEVERITY_HIGH",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SafetSettingsConfig(TypedDict, total=False):
|
||||||
|
category: HarmCategory
|
||||||
|
threshold: HarmBlockThreshold
|
||||||
|
max_influential_terms: int
|
||||||
|
method: HarmBlockMethod
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationConfig(TypedDict, total=False):
|
||||||
|
temperature: float
|
||||||
|
top_p: float
|
||||||
|
top_k: float
|
||||||
|
candidate_count: int
|
||||||
|
max_output_tokens: int
|
||||||
|
stop_sequences: List[str]
|
||||||
|
presence_penalty: float
|
||||||
|
frequency_penalty: float
|
||||||
|
response_mime_type: Literal["text/plain", "application/json"]
|
||||||
|
|
||||||
|
|
||||||
|
class Tools(TypedDict):
|
||||||
|
function_declarations: List[FunctionDeclaration]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfig(TypedDict):
|
||||||
|
functionCallingConfig: FunctionCallingConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RequestBody(TypedDict, total=False):
|
||||||
|
contents: Required[List[ContentType]]
|
||||||
|
system_instruction: SystemInstructions
|
||||||
|
tools: Tools
|
||||||
|
toolConfig: ToolConfig
|
||||||
|
safetySettings: SafetSettingsConfig
|
||||||
|
generationConfig: GenerationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyRatings(TypedDict):
|
||||||
|
category: HarmCategory
|
||||||
|
probability: HarmProbability
|
||||||
|
probabilityScore: int
|
||||||
|
severity: HarmSeverity
|
||||||
|
blocked: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Date(TypedDict):
|
||||||
|
year: int
|
||||||
|
month: int
|
||||||
|
date: int
|
||||||
|
|
||||||
|
|
||||||
|
class Citation(TypedDict):
|
||||||
|
startIndex: int
|
||||||
|
endIndex: int
|
||||||
|
uri: str
|
||||||
|
title: str
|
||||||
|
license: str
|
||||||
|
publicationDate: Date
|
||||||
|
|
||||||
|
|
||||||
|
class CitationMetadata(TypedDict):
|
||||||
|
citations: List[Citation]
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEntryPoint(TypedDict, total=False):
|
||||||
|
renderedContent: str
|
||||||
|
sdkBlob: str
|
||||||
|
|
||||||
|
|
||||||
|
class GroundingMetadata(TypedDict, total=False):
|
||||||
|
webSearchQueries: List[str]
|
||||||
|
searchEntryPoint: SearchEntryPoint
|
||||||
|
|
||||||
|
|
||||||
|
class Candidates(TypedDict, total=False):
|
||||||
|
index: int
|
||||||
|
content: HttpxContentType
|
||||||
|
finishReason: Literal[
|
||||||
|
"FINISH_REASON_UNSPECIFIED",
|
||||||
|
"STOP",
|
||||||
|
"MAX_TOKENS",
|
||||||
|
"SAFETY",
|
||||||
|
"RECITATION",
|
||||||
|
"OTHER",
|
||||||
|
"BLOCKLIST",
|
||||||
|
"PROHIBITED_CONTENT",
|
||||||
|
"SPII",
|
||||||
|
]
|
||||||
|
safetyRatings: SafetyRatings
|
||||||
|
citationMetadata: CitationMetadata
|
||||||
|
groundingMetadata: GroundingMetadata
|
||||||
|
finishMessage: str
|
||||||
|
|
||||||
|
|
||||||
|
class PromptFeedback(TypedDict):
|
||||||
|
blockReason: str
|
||||||
|
safetyRatings: List[SafetyRatings]
|
||||||
|
blockReasonMessage: str
|
||||||
|
|
||||||
|
|
||||||
|
class UsageMetadata(TypedDict):
|
||||||
|
promptTokenCount: int
|
||||||
|
totalTokenCount: int
|
||||||
|
candidatesTokenCount: int
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateContentResponseBody(TypedDict, total=False):
|
||||||
|
candidates: Required[List[Candidates]]
|
||||||
|
promptFeedback: PromptFeedback
|
||||||
|
usageMetadata: Required[UsageMetadata]
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from typing import List, Optional, Union, Dict, Tuple, Literal
|
from typing import List, Optional, Union, Dict, Tuple, Literal
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override, Required, Dict
|
||||||
|
from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMCommonStrings(Enum):
|
class LiteLLMCommonStrings(Enum):
|
||||||
|
@ -37,3 +39,12 @@ class ModelInfo(TypedDict):
|
||||||
"completion", "embedding", "image_generation", "chat", "audio_transcription"
|
"completion", "embedding", "image_generation", "chat", "audio_transcription"
|
||||||
]
|
]
|
||||||
supported_openai_params: Optional[List[str]]
|
supported_openai_params: Optional[List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class GenericStreamingChunk(TypedDict):
|
||||||
|
text: Required[str]
|
||||||
|
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||||
|
is_finished: Required[bool]
|
||||||
|
finish_reason: Required[str]
|
||||||
|
usage: Optional[ChatCompletionUsageBlock]
|
||||||
|
index: int
|
||||||
|
|
|
@ -518,15 +518,18 @@ class Choices(OpenAIObject):
|
||||||
self,
|
self,
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
index=0,
|
index=0,
|
||||||
message=None,
|
message: Optional[Union[Message, dict]] = None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
enhancements=None,
|
enhancements=None,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
super(Choices, self).__init__(**params)
|
super(Choices, self).__init__(**params)
|
||||||
self.finish_reason = (
|
if finish_reason is not None:
|
||||||
map_finish_reason(finish_reason) or "stop"
|
self.finish_reason = map_finish_reason(
|
||||||
) # set finish_reason for all responses
|
finish_reason
|
||||||
|
) # set finish_reason for all responses
|
||||||
|
else:
|
||||||
|
self.finish_reason = "stop"
|
||||||
self.index = index
|
self.index = index
|
||||||
if message is None:
|
if message is None:
|
||||||
self.message = Message()
|
self.message = Message()
|
||||||
|
@ -2822,7 +2825,9 @@ class Rules:
|
||||||
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
|
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def post_call_rules(self, input: str, model: str):
|
def post_call_rules(self, input: Optional[str], model: str) -> bool:
|
||||||
|
if input is None:
|
||||||
|
return True
|
||||||
for rule in litellm.post_call_rules:
|
for rule in litellm.post_call_rules:
|
||||||
if callable(rule):
|
if callable(rule):
|
||||||
decision = rule(input)
|
decision = rule(input)
|
||||||
|
@ -3101,9 +3106,9 @@ def client(original_function):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(original_response, ModelResponse):
|
if isinstance(original_response, ModelResponse):
|
||||||
model_response = original_response["choices"][0]["message"][
|
model_response = original_response.choices[
|
||||||
"content"
|
0
|
||||||
]
|
].message.content
|
||||||
### POST-CALL RULES ###
|
### POST-CALL RULES ###
|
||||||
rules_obj.post_call_rules(input=model_response, model=model)
|
rules_obj.post_call_rules(input=model_response, model=model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -5404,6 +5409,16 @@ def get_optional_params(
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
|
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "vertex_ai_beta":
|
||||||
|
supported_params = get_supported_openai_params(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
optional_params = litellm.VertexGeminiConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
|
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
|
||||||
):
|
):
|
||||||
|
@ -11258,6 +11273,34 @@ class CustomStreamWrapper:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion_obj["content"] = str(chunk)
|
completion_obj["content"] = str(chunk)
|
||||||
|
elif self.custom_llm_provider and (
|
||||||
|
self.custom_llm_provider == "vertex_ai_beta"
|
||||||
|
):
|
||||||
|
from litellm.types.utils import (
|
||||||
|
GenericStreamingChunk as UtilsStreamingChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.received_finish_reason is not None:
|
||||||
|
raise StopIteration
|
||||||
|
response_obj: UtilsStreamingChunk = chunk
|
||||||
|
completion_obj["content"] = response_obj["text"]
|
||||||
|
if response_obj["is_finished"]:
|
||||||
|
self.received_finish_reason = response_obj["finish_reason"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.stream_options
|
||||||
|
and self.stream_options.get("include_usage", False) is True
|
||||||
|
and response_obj["usage"] is not None
|
||||||
|
):
|
||||||
|
self.sent_stream_usage = True
|
||||||
|
model_response.usage = litellm.Usage(
|
||||||
|
prompt_tokens=response_obj["usage"]["prompt_tokens"],
|
||||||
|
completion_tokens=response_obj["usage"]["completion_tokens"],
|
||||||
|
total_tokens=response_obj["usage"]["total_tokens"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
|
||||||
|
completion_obj["tool_calls"] = [response_obj["tool_use"]]
|
||||||
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
|
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
|
||||||
import proto # type: ignore
|
import proto # type: ignore
|
||||||
|
|
||||||
|
@ -11935,6 +11978,7 @@ class CustomStreamWrapper:
|
||||||
or self.custom_llm_provider == "ollama"
|
or self.custom_llm_provider == "ollama"
|
||||||
or self.custom_llm_provider == "ollama_chat"
|
or self.custom_llm_provider == "ollama_chat"
|
||||||
or self.custom_llm_provider == "vertex_ai"
|
or self.custom_llm_provider == "vertex_ai"
|
||||||
|
or self.custom_llm_provider == "vertex_ai_beta"
|
||||||
or self.custom_llm_provider == "sagemaker"
|
or self.custom_llm_provider == "sagemaker"
|
||||||
or self.custom_llm_provider == "gemini"
|
or self.custom_llm_provider == "gemini"
|
||||||
or self.custom_llm_provider == "replicate"
|
or self.custom_llm_provider == "replicate"
|
||||||
|
|
10
log.txt
Normal file
10
log.txt
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
============================= test session starts ==============================
|
||||||
|
platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/bin/python3.11
|
||||||
|
cachedir: .pytest_cache
|
||||||
|
rootdir: /Users/krrishdholakia/Documents/litellm
|
||||||
|
configfile: pyproject.toml
|
||||||
|
plugins: logfire-0.35.0, asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
|
||||||
|
asyncio: mode=Mode.STRICT
|
||||||
|
collecting ... collected 0 items
|
||||||
|
|
||||||
|
============================ no tests ran in 0.00s =============================
|
Loading…
Add table
Add a link
Reference in a new issue