mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(vertex_ai_partner.py): initial working commit for calling vertex ai mistral
Closes https://github.com/BerriAI/litellm/issues/4874
This commit is contained in:
parent
1a8f45e8da
commit
5b71421a7b
10 changed files with 343 additions and 140 deletions
|
@ -358,6 +358,7 @@ vertex_code_text_models: List = []
|
||||||
vertex_embedding_models: List = []
|
vertex_embedding_models: List = []
|
||||||
vertex_anthropic_models: List = []
|
vertex_anthropic_models: List = []
|
||||||
vertex_llama3_models: List = []
|
vertex_llama3_models: List = []
|
||||||
|
vertex_mistral_models: List = []
|
||||||
ai21_models: List = []
|
ai21_models: List = []
|
||||||
nlp_cloud_models: List = []
|
nlp_cloud_models: List = []
|
||||||
aleph_alpha_models: List = []
|
aleph_alpha_models: List = []
|
||||||
|
@ -403,6 +404,9 @@ for key, value in model_cost.items():
|
||||||
elif value.get("litellm_provider") == "vertex_ai-llama_models":
|
elif value.get("litellm_provider") == "vertex_ai-llama_models":
|
||||||
key = key.replace("vertex_ai/", "")
|
key = key.replace("vertex_ai/", "")
|
||||||
vertex_llama3_models.append(key)
|
vertex_llama3_models.append(key)
|
||||||
|
elif value.get("litellm_provider") == "vertex_ai-mistral_models":
|
||||||
|
key = key.replace("vertex_ai/", "")
|
||||||
|
vertex_mistral_models.append(key)
|
||||||
elif value.get("litellm_provider") == "ai21":
|
elif value.get("litellm_provider") == "ai21":
|
||||||
ai21_models.append(key)
|
ai21_models.append(key)
|
||||||
elif value.get("litellm_provider") == "nlp_cloud":
|
elif value.get("litellm_provider") == "nlp_cloud":
|
||||||
|
@ -833,7 +837,7 @@ from .llms.petals import PetalsConfig
|
||||||
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
|
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
|
||||||
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
||||||
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
||||||
from .llms.vertex_ai_llama import VertexAILlama3Config
|
from .llms.vertex_ai_partner import VertexAILlama3Config
|
||||||
from .llms.sagemaker import SagemakerConfig
|
from .llms.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
from .llms.ollama_chat import OllamaChatConfig
|
from .llms.ollama_chat import OllamaChatConfig
|
||||||
|
|
|
@ -15,8 +15,14 @@ import requests # type: ignore
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.types.llms.databricks import GenericStreamingChunk
|
from litellm.types.llms.openai import (
|
||||||
from litellm.types.utils import ProviderField
|
ChatCompletionDeltaChunk,
|
||||||
|
ChatCompletionResponseMessage,
|
||||||
|
ChatCompletionToolCallChunk,
|
||||||
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
ChatCompletionUsageBlock,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import GenericStreamingChunk, ProviderField
|
||||||
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
|
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
|
||||||
|
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
|
@ -114,71 +120,6 @@ class DatabricksConfig:
|
||||||
optional_params["stop"] = value
|
optional_params["stop"] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
|
|
||||||
try:
|
|
||||||
text = ""
|
|
||||||
is_finished = False
|
|
||||||
finish_reason = None
|
|
||||||
logprobs = None
|
|
||||||
usage = None
|
|
||||||
original_chunk = None # this is used for function/tool calling
|
|
||||||
chunk_data = chunk_data.replace("data:", "")
|
|
||||||
chunk_data = chunk_data.strip()
|
|
||||||
if len(chunk_data) == 0 or chunk_data == "[DONE]":
|
|
||||||
return {
|
|
||||||
"text": "",
|
|
||||||
"is_finished": is_finished,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
}
|
|
||||||
chunk_data_dict = json.loads(chunk_data)
|
|
||||||
str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
|
|
||||||
|
|
||||||
if len(str_line.choices) > 0:
|
|
||||||
if (
|
|
||||||
str_line.choices[0].delta is not None # type: ignore
|
|
||||||
and str_line.choices[0].delta.content is not None # type: ignore
|
|
||||||
):
|
|
||||||
text = str_line.choices[0].delta.content # type: ignore
|
|
||||||
else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
|
|
||||||
original_chunk = str_line
|
|
||||||
if str_line.choices[0].finish_reason:
|
|
||||||
is_finished = True
|
|
||||||
finish_reason = str_line.choices[0].finish_reason
|
|
||||||
if finish_reason == "content_filter":
|
|
||||||
if hasattr(str_line.choices[0], "content_filter_result"):
|
|
||||||
error_message = json.dumps(
|
|
||||||
str_line.choices[0].content_filter_result # type: ignore
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
error_message = "Azure Response={}".format(
|
|
||||||
str(dict(str_line))
|
|
||||||
)
|
|
||||||
raise litellm.AzureOpenAIError(
|
|
||||||
status_code=400, message=error_message
|
|
||||||
)
|
|
||||||
|
|
||||||
# checking for logprobs
|
|
||||||
if (
|
|
||||||
hasattr(str_line.choices[0], "logprobs")
|
|
||||||
and str_line.choices[0].logprobs is not None
|
|
||||||
):
|
|
||||||
logprobs = str_line.choices[0].logprobs
|
|
||||||
else:
|
|
||||||
logprobs = None
|
|
||||||
|
|
||||||
usage = getattr(str_line, "usage", None)
|
|
||||||
|
|
||||||
return GenericStreamingChunk(
|
|
||||||
text=text,
|
|
||||||
is_finished=is_finished,
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
logprobs=logprobs,
|
|
||||||
original_chunk=original_chunk,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksEmbeddingConfig:
|
class DatabricksEmbeddingConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -236,7 +177,9 @@ async def make_call(
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise DatabricksError(status_code=response.status_code, message=response.text)
|
raise DatabricksError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
completion_stream = response.aiter_lines()
|
completion_stream = ModelResponseIterator(
|
||||||
|
streaming_response=response.aiter_lines(), sync_stream=False
|
||||||
|
)
|
||||||
# LOGGING
|
# LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -248,6 +191,38 @@ async def make_call(
|
||||||
return completion_stream
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
|
def make_sync_call(
|
||||||
|
client: Optional[HTTPHandler],
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: str,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
logging_obj,
|
||||||
|
):
|
||||||
|
if client is None:
|
||||||
|
client = HTTPHandler() # Create a new client if none provided
|
||||||
|
|
||||||
|
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise DatabricksError(status_code=response.status_code, message=response.read())
|
||||||
|
|
||||||
|
completion_stream = ModelResponseIterator(
|
||||||
|
streaming_response=response.iter_lines(), sync_stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response="first stream response received",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion_stream
|
||||||
|
|
||||||
|
|
||||||
class DatabricksChatCompletion(BaseLLM):
|
class DatabricksChatCompletion(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -259,6 +234,7 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||||
|
custom_endpoint: Optional[bool],
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
raise DatabricksError(
|
raise DatabricksError(
|
||||||
|
@ -277,9 +253,9 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
if endpoint_type == "chat_completions":
|
if endpoint_type == "chat_completions" and custom_endpoint is not True:
|
||||||
api_base = "{}/chat/completions".format(api_base)
|
api_base = "{}/chat/completions".format(api_base)
|
||||||
elif endpoint_type == "embeddings":
|
elif endpoint_type == "embeddings" and custom_endpoint is not True:
|
||||||
api_base = "{}/embeddings".format(api_base)
|
api_base = "{}/embeddings".format(api_base)
|
||||||
return api_base, headers
|
return api_base, headers
|
||||||
|
|
||||||
|
@ -464,8 +440,12 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
|
custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
|
||||||
api_base, headers = self._validate_environment(
|
api_base, headers = self._validate_environment(
|
||||||
api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
endpoint_type="chat_completions",
|
||||||
|
custom_endpoint=custom_endpoint,
|
||||||
)
|
)
|
||||||
## Load Config
|
## Load Config
|
||||||
config = litellm.DatabricksConfig().get_config()
|
config = litellm.DatabricksConfig().get_config()
|
||||||
|
@ -475,7 +455,8 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
optional_params[k] = v
|
optional_params[k] = v
|
||||||
|
|
||||||
stream = optional_params.pop("stream", None)
|
stream: bool = optional_params.pop("stream", None) or False
|
||||||
|
optional_params["stream"] = stream
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
|
@ -539,41 +520,26 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
self.client = HTTPHandler(timeout=timeout) # type: ignore
|
client = HTTPHandler(timeout=timeout) # type: ignore
|
||||||
else:
|
|
||||||
self.client = client
|
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
if (
|
if stream is True:
|
||||||
stream is not None and stream == True
|
return CustomStreamWrapper(
|
||||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
completion_stream=None,
|
||||||
print_verbose("makes dbrx streaming POST request")
|
make_call=partial(
|
||||||
data["stream"] = stream
|
make_sync_call,
|
||||||
try:
|
client=None,
|
||||||
response = self.client.post(
|
api_base=api_base,
|
||||||
api_base, headers=headers, data=json.dumps(data), stream=stream
|
headers=headers, # type: ignore
|
||||||
)
|
data=json.dumps(data),
|
||||||
response.raise_for_status()
|
model=model,
|
||||||
completion_stream = response.iter_lines()
|
messages=messages,
|
||||||
except httpx.HTTPStatusError as e:
|
logging_obj=logging_obj,
|
||||||
raise DatabricksError(
|
),
|
||||||
status_code=e.response.status_code, message=response.text
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException as e:
|
|
||||||
raise DatabricksError(
|
|
||||||
status_code=408, message="Timeout error occurred."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise DatabricksError(status_code=408, message=str(e))
|
|
||||||
|
|
||||||
streaming_response = CustomStreamWrapper(
|
|
||||||
completion_stream=completion_stream,
|
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="databricks",
|
custom_llm_provider="vertex_ai_beta",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
return streaming_response
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
|
@ -667,7 +633,10 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
aembedding=None,
|
aembedding=None,
|
||||||
) -> EmbeddingResponse:
|
) -> EmbeddingResponse:
|
||||||
api_base, headers = self._validate_environment(
|
api_base, headers = self._validate_environment(
|
||||||
api_base=api_base, api_key=api_key, endpoint_type="embeddings"
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
endpoint_type="embeddings",
|
||||||
|
custom_endpoint=False,
|
||||||
)
|
)
|
||||||
model = model
|
model = model
|
||||||
data = {"model": model, "input": input, **optional_params}
|
data = {"model": model, "input": input, **optional_params}
|
||||||
|
@ -716,3 +685,126 @@ class DatabricksChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
return litellm.EmbeddingResponse(**response_json)
|
return litellm.EmbeddingResponse(**response_json)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelResponseIterator:
|
||||||
|
def __init__(self, streaming_response, sync_stream: bool):
|
||||||
|
self.streaming_response = streaming_response
|
||||||
|
|
||||||
|
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||||
|
try:
|
||||||
|
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||||
|
is_finished = False
|
||||||
|
finish_reason = ""
|
||||||
|
usage: Optional[ChatCompletionUsageBlock] = None
|
||||||
|
|
||||||
|
if processed_chunk.choices[0].delta.content is not None: # type: ignore
|
||||||
|
text = processed_chunk.choices[0].delta.content # type: ignore
|
||||||
|
|
||||||
|
if (
|
||||||
|
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
|
||||||
|
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
|
||||||
|
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
|
||||||
|
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
|
||||||
|
is not None
|
||||||
|
):
|
||||||
|
tool_use = ChatCompletionToolCallChunk(
|
||||||
|
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=processed_chunk.choices[0]
|
||||||
|
.delta.tool_calls[0] # type: ignore
|
||||||
|
.function.name,
|
||||||
|
arguments=processed_chunk.choices[0]
|
||||||
|
.delta.tool_calls[0] # type: ignore
|
||||||
|
.function.arguments,
|
||||||
|
),
|
||||||
|
index=processed_chunk.choices[0].index,
|
||||||
|
)
|
||||||
|
|
||||||
|
if processed_chunk.choices[0].finish_reason is not None:
|
||||||
|
is_finished = True
|
||||||
|
finish_reason = processed_chunk.choices[0].finish_reason
|
||||||
|
|
||||||
|
if hasattr(processed_chunk, "usage"):
|
||||||
|
usage = processed_chunk.usage # type: ignore
|
||||||
|
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text=text,
|
||||||
|
tool_use=tool_use,
|
||||||
|
is_finished=is_finished,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||||
|
|
||||||
|
# Sync iterator
|
||||||
|
def __iter__(self):
|
||||||
|
self.response_iterator = self.streaming_response
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
try:
|
||||||
|
chunk = self.response_iterator.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = chunk.replace("data:", "")
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if len(chunk) > 0:
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
else:
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text="",
|
||||||
|
is_finished=False,
|
||||||
|
finish_reason="",
|
||||||
|
usage=None,
|
||||||
|
index=0,
|
||||||
|
tool_use=None,
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||||
|
|
||||||
|
# Async iterator
|
||||||
|
def __aiter__(self):
|
||||||
|
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
try:
|
||||||
|
chunk = await self.async_response_iterator.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = chunk.replace("data:", "")
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if len(chunk) > 0:
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
else:
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text="",
|
||||||
|
is_finished=False,
|
||||||
|
finish_reason="",
|
||||||
|
usage=None,
|
||||||
|
index=0,
|
||||||
|
tool_use=None,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||||
|
|
|
@ -160,7 +160,7 @@ class MistralConfig:
|
||||||
optional_params["max_tokens"] = value
|
optional_params["max_tokens"] = value
|
||||||
if param == "tools":
|
if param == "tools":
|
||||||
optional_params["tools"] = value
|
optional_params["tools"] = value
|
||||||
if param == "stream" and value == True:
|
if param == "stream" and value is True:
|
||||||
optional_params["stream"] = value
|
optional_params["stream"] = value
|
||||||
if param == "temperature":
|
if param == "temperature":
|
||||||
optional_params["temperature"] = value
|
optional_params["temperature"] = value
|
||||||
|
|
|
@ -7,7 +7,7 @@ import time
|
||||||
import types
|
import types
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
import requests # type: ignore
|
import requests # type: ignore
|
||||||
|
@ -108,14 +108,25 @@ class VertexAILlama3Config:
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
class VertexAILlama3(BaseLLM):
|
class VertexAIPartnerModels(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_vertex_llama3_url(
|
def create_vertex_url(
|
||||||
self, vertex_location: str, vertex_project: str
|
self,
|
||||||
|
vertex_location: str,
|
||||||
|
vertex_project: str,
|
||||||
|
partner: Literal["llama", "mistralai"],
|
||||||
|
stream: Optional[bool],
|
||||||
|
model: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
|
if partner == "llama":
|
||||||
|
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
|
||||||
|
elif partner == "mistralai":
|
||||||
|
if stream:
|
||||||
|
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
|
||||||
|
else:
|
||||||
|
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -141,6 +152,7 @@ class VertexAILlama3(BaseLLM):
|
||||||
import vertexai
|
import vertexai
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
|
from litellm.llms.databricks import DatabricksChatCompletion
|
||||||
from litellm.llms.openai import OpenAIChatCompletion
|
from litellm.llms.openai import OpenAIChatCompletion
|
||||||
from litellm.llms.vertex_httpx import VertexLLM
|
from litellm.llms.vertex_httpx import VertexLLM
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -166,6 +178,7 @@ class VertexAILlama3(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_chat_completions = OpenAIChatCompletion()
|
openai_chat_completions = OpenAIChatCompletion()
|
||||||
|
openai_like_chat_completions = DatabricksChatCompletion()
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
# config = litellm.VertexAILlama3.get_config()
|
# config = litellm.VertexAILlama3.get_config()
|
||||||
|
@ -178,12 +191,23 @@ class VertexAILlama3(BaseLLM):
|
||||||
|
|
||||||
optional_params["stream"] = stream
|
optional_params["stream"] = stream
|
||||||
|
|
||||||
api_base = self.create_vertex_llama3_url(
|
if "llama" in model:
|
||||||
|
partner = "llama"
|
||||||
|
elif "mistral" in model:
|
||||||
|
partner = "mistralai"
|
||||||
|
optional_params["custom_endpoint"] = True
|
||||||
|
|
||||||
|
api_base = self.create_vertex_url(
|
||||||
vertex_location=vertex_location or "us-central1",
|
vertex_location=vertex_location or "us-central1",
|
||||||
vertex_project=vertex_project or project_id,
|
vertex_project=vertex_project or project_id,
|
||||||
|
partner=partner, # type: ignore
|
||||||
|
stream=stream,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_chat_completions.completion(
|
model = model.split("@")[0]
|
||||||
|
|
||||||
|
return openai_like_chat_completions.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -198,6 +222,7 @@ class VertexAILlama3(BaseLLM):
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
client=client,
|
client=client,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
|
@ -121,7 +121,7 @@ from .llms.prompt_templates.factory import (
|
||||||
)
|
)
|
||||||
from .llms.text_completion_codestral import CodestralTextCompletion
|
from .llms.text_completion_codestral import CodestralTextCompletion
|
||||||
from .llms.triton import TritonChatCompletion
|
from .llms.triton import TritonChatCompletion
|
||||||
from .llms.vertex_ai_llama import VertexAILlama3
|
from .llms.vertex_ai_partner import VertexAIPartnerModels
|
||||||
from .llms.vertex_httpx import VertexLLM
|
from .llms.vertex_httpx import VertexLLM
|
||||||
from .llms.watsonx import IBMWatsonXAI
|
from .llms.watsonx import IBMWatsonXAI
|
||||||
from .types.llms.openai import HttpxBinaryResponseContent
|
from .types.llms.openai import HttpxBinaryResponseContent
|
||||||
|
@ -158,7 +158,7 @@ triton_chat_completions = TritonChatCompletion()
|
||||||
bedrock_chat_completion = BedrockLLM()
|
bedrock_chat_completion = BedrockLLM()
|
||||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||||
vertex_chat_completion = VertexLLM()
|
vertex_chat_completion = VertexLLM()
|
||||||
vertex_llama_chat_completion = VertexAILlama3()
|
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||||
watsonxai = IBMWatsonXAI()
|
watsonxai = IBMWatsonXAI()
|
||||||
####### COMPLETION ENDPOINTS ################
|
####### COMPLETION ENDPOINTS ################
|
||||||
|
|
||||||
|
@ -2068,8 +2068,8 @@ def completion(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
elif model.startswith("meta/"):
|
elif model.startswith("meta/") or model.startswith("mistral"):
|
||||||
model_response = vertex_llama_chat_completion.completion(
|
model_response = vertex_partner_models_chat_completion.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
|
|
|
@ -2028,6 +2028,16 @@
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
|
||||||
},
|
},
|
||||||
|
"vertex_ai/mistral-large@latest": {
|
||||||
|
"max_tokens": 8191,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 8191,
|
||||||
|
"input_cost_per_token": 0.000003,
|
||||||
|
"output_cost_per_token": 0.000009,
|
||||||
|
"litellm_provider": "vertex_ai-mistral_models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true
|
||||||
|
},
|
||||||
"vertex_ai/mistral-large@2407": {
|
"vertex_ai/mistral-large@2407": {
|
||||||
"max_tokens": 8191,
|
"max_tokens": 8191,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
|
|
@ -899,16 +899,18 @@ from litellm.tests.test_completion import response_format_tests
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
|
"model",
|
||||||
|
[
|
||||||
|
"vertex_ai/mistral-large@2407",
|
||||||
|
"vertex_ai/meta/llama3-405b-instruct-maas",
|
||||||
|
], #
|
||||||
) # "vertex_ai",
|
) # "vertex_ai",
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"sync_mode",
|
"sync_mode",
|
||||||
[
|
[True, False],
|
||||||
True,
|
) #
|
||||||
],
|
|
||||||
) # False
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llama_3_httpx(model, sync_mode):
|
async def test_partner_models_httpx(model, sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
@ -946,6 +948,57 @@ async def test_llama_3_httpx(model, sync_mode):
|
||||||
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
|
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"vertex_ai/mistral-large@2407",
|
||||||
|
"vertex_ai/meta/llama3-405b-instruct-maas",
|
||||||
|
], #
|
||||||
|
) # "vertex_ai",
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sync_mode",
|
||||||
|
[True, False], #
|
||||||
|
) #
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partner_models_httpx_streaming(model, sync_mode):
|
||||||
|
try:
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {"model": model, "messages": messages, "stream": True}
|
||||||
|
if sync_mode:
|
||||||
|
response = litellm.completion(**data)
|
||||||
|
for idx, chunk in enumerate(response):
|
||||||
|
streaming_format_tests(idx=idx, chunk=chunk)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(**data)
|
||||||
|
idx = 0
|
||||||
|
async for chunk in response:
|
||||||
|
streaming_format_tests(idx=idx, chunk=chunk)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
except litellm.RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
if "429 Quota exceeded" in str(e):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
|
||||||
|
|
||||||
|
|
||||||
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
|
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
|
|
|
@ -141,6 +141,21 @@ def test_vertex_ai_llama_3_optional_params():
|
||||||
assert "user" not in optional_params
|
assert "user" not in optional_params
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertex_ai_mistral_optional_params():
|
||||||
|
litellm.vertex_mistral_models = ["mistral-large@2407"]
|
||||||
|
litellm.drop_params = True
|
||||||
|
optional_params = get_optional_params(
|
||||||
|
model="mistral-large@2407",
|
||||||
|
user="John",
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.2,
|
||||||
|
)
|
||||||
|
assert "user" not in optional_params
|
||||||
|
assert "max_tokens" in optional_params
|
||||||
|
assert "temperature" in optional_params
|
||||||
|
|
||||||
|
|
||||||
def test_azure_gpt_optional_params_gpt_vision():
|
def test_azure_gpt_optional_params_gpt_vision():
|
||||||
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
|
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
|
||||||
optional_params = litellm.utils.get_optional_params(
|
optional_params = litellm.utils.get_optional_params(
|
||||||
|
|
|
@ -3104,6 +3104,15 @@ def get_optional_params(
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
|
||||||
|
supported_params = get_supported_openai_params(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
optional_params = litellm.MistralConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
elif custom_llm_provider == "sagemaker":
|
elif custom_llm_provider == "sagemaker":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -4210,7 +4219,8 @@ def get_supported_openai_params(
|
||||||
if request_type == "chat_completion":
|
if request_type == "chat_completion":
|
||||||
if model.startswith("meta/"):
|
if model.startswith("meta/"):
|
||||||
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
||||||
|
if model.startswith("mistral"):
|
||||||
|
return litellm.MistralConfig().get_supported_openai_params()
|
||||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||||
elif request_type == "embeddings":
|
elif request_type == "embeddings":
|
||||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||||
|
@ -9631,22 +9641,6 @@ class CustomStreamWrapper:
|
||||||
completion_tokens=response_obj["usage"].completion_tokens,
|
completion_tokens=response_obj["usage"].completion_tokens,
|
||||||
total_tokens=response_obj["usage"].total_tokens,
|
total_tokens=response_obj["usage"].total_tokens,
|
||||||
)
|
)
|
||||||
elif self.custom_llm_provider == "databricks":
|
|
||||||
response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
|
|
||||||
completion_obj["content"] = response_obj["text"]
|
|
||||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
|
||||||
if response_obj["is_finished"]:
|
|
||||||
self.received_finish_reason = response_obj["finish_reason"]
|
|
||||||
if (
|
|
||||||
self.stream_options
|
|
||||||
and self.stream_options.get("include_usage", False) == True
|
|
||||||
and response_obj["usage"] is not None
|
|
||||||
):
|
|
||||||
model_response.usage = litellm.Usage(
|
|
||||||
prompt_tokens=response_obj["usage"].prompt_tokens,
|
|
||||||
completion_tokens=response_obj["usage"].completion_tokens,
|
|
||||||
total_tokens=response_obj["usage"].total_tokens,
|
|
||||||
)
|
|
||||||
elif self.custom_llm_provider == "azure_text":
|
elif self.custom_llm_provider == "azure_text":
|
||||||
response_obj = self.handle_azure_text_completion_chunk(chunk)
|
response_obj = self.handle_azure_text_completion_chunk(chunk)
|
||||||
completion_obj["content"] = response_obj["text"]
|
completion_obj["content"] = response_obj["text"]
|
||||||
|
|
|
@ -2028,6 +2028,16 @@
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
|
||||||
},
|
},
|
||||||
|
"vertex_ai/mistral-large@latest": {
|
||||||
|
"max_tokens": 8191,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 8191,
|
||||||
|
"input_cost_per_token": 0.000003,
|
||||||
|
"output_cost_per_token": 0.000009,
|
||||||
|
"litellm_provider": "vertex_ai-mistral_models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true
|
||||||
|
},
|
||||||
"vertex_ai/mistral-large@2407": {
|
"vertex_ai/mistral-large@2407": {
|
||||||
"max_tokens": 8191,
|
"max_tokens": 8191,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue