Merge pull request #4925 from BerriAI/litellm_vertex_mistral

feat(vertex_ai_partner.py): Vertex AI Mistral Support
Krish Dholakia 2024-07-27 21:51:26 -07:00 committed by GitHub
commit e3a94ac013
GPG key ID: B5690EEEBB952194
10 changed files with 365 additions and 147 deletions
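
Before the diff, a minimal usage sketch of what this PR enables. The model name and message shape are taken from the new tests further down; Vertex AI credentials are assumed to already be configured in the environment (the tests do this via load_vertex_ai_credentials()):

    import litellm

    messages = [
        {"role": "system", "content": "Your name is Litellm Bot, you are a helpful assistant"},
        {"role": "user", "content": "Hello, what is your name?"},
    ]

    # Non-streaming call, routed through the new VertexAIPartnerModels handler
    response = litellm.completion(
        model="vertex_ai/mistral-large@2407",
        messages=messages,
        max_tokens=10,
    )

    # Streaming call; chunks are parsed by the new ModelResponseIterator
    for chunk in litellm.completion(
        model="vertex_ai/mistral-large@2407",
        messages=messages,
        stream=True,
    ):
        print(chunk)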

View file

@@ -358,6 +358,7 @@ vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
vertex_llama3_models: List = []
+vertex_mistral_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@@ -403,6 +404,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-llama_models":
key = key.replace("vertex_ai/", "")
vertex_llama3_models.append(key)
+elif value.get("litellm_provider") == "vertex_ai-mistral_models":
+key = key.replace("vertex_ai/", "")
+vertex_mistral_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":
@@ -833,7 +837,7 @@ from .llms.petals import PetalsConfig
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
-from .llms.vertex_ai_llama import VertexAILlama3Config
+from .llms.vertex_ai_partner import VertexAILlama3Config
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
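
As a quick illustration of the registration above (a sketch, not part of the diff): any cost-map entry whose litellm_provider is "vertex_ai-mistral_models" is appended to litellm.vertex_mistral_models at import time, with the "vertex_ai/" prefix stripped.

    import litellm

    # With the default cost map shipped in this PR, this is expected to include
    # "mistral-large@2407" and "mistral-large@latest".
    print(litellm.vertex_mistral_models)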

View file

@@ -15,8 +15,14 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.llms.openai import (
+ChatCompletionDeltaChunk,
+ChatCompletionResponseMessage,
+ChatCompletionToolCallChunk,
+ChatCompletionToolCallFunctionChunk,
+ChatCompletionUsageBlock,
+)
+from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
@@ -114,71 +120,6 @@ class DatabricksConfig:
optional_params["stop"] = value
return optional_params
-def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
-try:
-text = ""
-is_finished = False
-finish_reason = None
-logprobs = None
-usage = None
-original_chunk = None # this is used for function/tool calling
-chunk_data = chunk_data.replace("data:", "")
-chunk_data = chunk_data.strip()
-if len(chunk_data) == 0 or chunk_data == "[DONE]":
-return {
-"text": "",
-"is_finished": is_finished,
-"finish_reason": finish_reason,
-}
-chunk_data_dict = json.loads(chunk_data)
-str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
-if len(str_line.choices) > 0:
-if (
-str_line.choices[0].delta is not None # type: ignore
-and str_line.choices[0].delta.content is not None # type: ignore
-):
-text = str_line.choices[0].delta.content # type: ignore
-else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
-original_chunk = str_line
-if str_line.choices[0].finish_reason:
-is_finished = True
-finish_reason = str_line.choices[0].finish_reason
-if finish_reason == "content_filter":
-if hasattr(str_line.choices[0], "content_filter_result"):
-error_message = json.dumps(
-str_line.choices[0].content_filter_result # type: ignore
-)
-else:
-error_message = "Azure Response={}".format(
-str(dict(str_line))
-)
-raise litellm.AzureOpenAIError(
-status_code=400, message=error_message
-)
-# checking for logprobs
-if (
-hasattr(str_line.choices[0], "logprobs")
-and str_line.choices[0].logprobs is not None
-):
-logprobs = str_line.choices[0].logprobs
-else:
-logprobs = None
-usage = getattr(str_line, "usage", None)
-return GenericStreamingChunk(
-text=text,
-is_finished=is_finished,
-finish_reason=finish_reason,
-logprobs=logprobs,
-original_chunk=original_chunk,
-usage=usage,
-)
-except Exception as e:
-raise e
class DatabricksEmbeddingConfig:
"""
@@ -236,7 +177,9 @@ async def make_call(
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.text)
-completion_stream = response.aiter_lines()
+completion_stream = ModelResponseIterator(
+streaming_response=response.aiter_lines(), sync_stream=False
+)
# LOGGING
logging_obj.post_call(
input=messages,
@@ -248,6 +191,38 @@ async def make_call(
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
class DatabricksChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -259,6 +234,7 @@ class DatabricksChatCompletion(BaseLLM):
api_key: Optional[str],
api_base: Optional[str],
endpoint_type: Literal["chat_completions", "embeddings"],
+custom_endpoint: Optional[bool],
) -> Tuple[str, dict]:
if api_key is None:
raise DatabricksError(
@@ -277,9 +253,9 @@ class DatabricksChatCompletion(BaseLLM):
"Content-Type": "application/json",
}
-if endpoint_type == "chat_completions":
+if endpoint_type == "chat_completions" and custom_endpoint is not True:
api_base = "{}/chat/completions".format(api_base)
-elif endpoint_type == "embeddings":
+elif endpoint_type == "embeddings" and custom_endpoint is not True:
api_base = "{}/embeddings".format(api_base)
return api_base, headers
@@ -368,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
self,
model: str,
messages: list,
+custom_llm_provider: str,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
@@ -397,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
logging_obj=logging_obj,
),
model=model,
-custom_llm_provider="databricks",
+custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streamwrapper
@@ -450,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
model: str,
messages: list,
api_base: str,
+custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
@@ -464,8 +442,12 @@ class DatabricksChatCompletion(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
+custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
api_base, headers = self._validate_environment(
-api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
+api_base=api_base,
+api_key=api_key,
+endpoint_type="chat_completions",
+custom_endpoint=custom_endpoint,
)
## Load Config
config = litellm.DatabricksConfig().get_config()
@@ -475,7 +457,8 @@ class DatabricksChatCompletion(BaseLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
-stream = optional_params.pop("stream", None)
+stream: bool = optional_params.pop("stream", None) or False
+optional_params["stream"] = stream
data = {
"model": model,
@@ -518,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
logger_fn=logger_fn,
headers=headers,
client=client,
+custom_llm_provider=custom_llm_provider,
)
else:
return self.acompletion_function(
@@ -539,44 +523,29 @@ class DatabricksChatCompletion(BaseLLM):
timeout=timeout,
)
else:
-if client is None or isinstance(client, AsyncHTTPHandler):
-self.client = HTTPHandler(timeout=timeout) # type: ignore
-else:
-self.client = client
+if client is None or not isinstance(client, HTTPHandler):
+client = HTTPHandler(timeout=timeout) # type: ignore
## COMPLETION CALL
-if (
-stream is not None and stream == True
-): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
-print_verbose("makes dbrx streaming POST request")
-data["stream"] = stream
-try:
-response = self.client.post(
-api_base, headers=headers, data=json.dumps(data), stream=stream
-)
-response.raise_for_status()
-completion_stream = response.iter_lines()
-except httpx.HTTPStatusError as e:
-raise DatabricksError(
-status_code=e.response.status_code, message=response.text
-)
-except httpx.TimeoutException as e:
-raise DatabricksError(
-status_code=408, message="Timeout error occurred."
-)
-except Exception as e:
-raise DatabricksError(status_code=408, message=str(e))
-streaming_response = CustomStreamWrapper(
-completion_stream=completion_stream,
+if stream is True:
+return CustomStreamWrapper(
+completion_stream=None,
+make_call=partial(
+make_sync_call,
+client=None,
+api_base=api_base,
+headers=headers, # type: ignore
+data=json.dumps(data),
+model=model,
+messages=messages,
+logging_obj=logging_obj,
+),
model=model,
-custom_llm_provider="databricks",
+custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
-return streaming_response
else:
try:
-response = self.client.post(
+response = client.post(
api_base, headers=headers, data=json.dumps(data)
)
response.raise_for_status()
@@ -667,7 +636,10 @@ class DatabricksChatCompletion(BaseLLM):
aembedding=None,
) -> EmbeddingResponse:
api_base, headers = self._validate_environment(
-api_base=api_base, api_key=api_key, endpoint_type="embeddings"
+api_base=api_base,
+api_key=api_key,
+endpoint_type="embeddings",
+custom_endpoint=False,
)
model = model
data = {"model": model, "input": input, **optional_params}
@@ -716,3 +688,128 @@ class DatabricksChatCompletion(BaseLLM):
)
return litellm.EmbeddingResponse(**response_json)
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage"):
usage = processed_chunk.usage # type: ignore
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@@ -160,7 +160,7 @@ class MistralConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
-if param == "stream" and value == True:
+if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value

View file

@@ -7,7 +7,7 @@ import time
import types
import uuid
from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -108,14 +108,25 @@ class VertexAILlama3Config:
return optional_params
-class VertexAILlama3(BaseLLM):
+class VertexAIPartnerModels(BaseLLM):
def __init__(self) -> None:
pass
-def create_vertex_llama3_url(
-self, vertex_location: str, vertex_project: str
+def create_vertex_url(
+self,
+vertex_location: str,
+vertex_project: str,
+partner: Literal["llama", "mistralai"],
+stream: Optional[bool],
+model: str,
) -> str:
-return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+if partner == "llama":
+return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+elif partner == "mistralai":
+if stream:
+return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+else:
+return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
def completion(
self,
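
For reference, a sketch of the URLs the new create_vertex_url builds; the project and location values are placeholders, not part of the diff:

    url = VertexAIPartnerModels().create_vertex_url(
        vertex_location="us-central1",   # placeholder
        vertex_project="my-project",     # placeholder
        partner="mistralai",
        stream=True,
        model="mistral-large@2407",
    )
    # -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1
    #    /publishers/mistralai/models/mistral-large@2407:streamRawPredict
    # With stream=False the path ends in :rawPredict; partner="llama" keeps the
    # existing v1beta1 .../endpoints/openapi URL.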
@@ -141,6 +152,7 @@ class VertexAILlama3(BaseLLM):
import vertexai
from google.cloud import aiplatform
+from litellm.llms.databricks import DatabricksChatCompletion
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
@@ -165,7 +177,7 @@ class VertexAILlama3(BaseLLM):
credentials=vertex_credentials, project_id=vertex_project
)
-openai_chat_completions = OpenAIChatCompletion()
+openai_like_chat_completions = DatabricksChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
@@ -178,12 +190,23 @@ class VertexAILlama3(BaseLLM):
optional_params["stream"] = stream
-api_base = self.create_vertex_llama3_url(
+if "llama" in model:
+partner = "llama"
+elif "mistral" in model:
+partner = "mistralai"
+optional_params["custom_endpoint"] = True
+api_base = self.create_vertex_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
+partner=partner, # type: ignore
+stream=stream,
+model=model,
)
-return openai_chat_completions.completion(
+model = model.split("@")[0]
+return openai_like_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -198,6 +221,8 @@ class VertexAILlama3(BaseLLM):
logger_fn=logger_fn,
client=client,
timeout=timeout,
+encoding=encoding,
+custom_llm_provider="vertex_ai_beta",
)
except Exception as e:

View file

@@ -121,7 +121,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_llama import VertexAILlama3
+from .llms.vertex_ai_partner import VertexAIPartnerModels
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
@@ -158,7 +158,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
-vertex_llama_chat_completion = VertexAILlama3()
+vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -1867,6 +1867,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
+custom_llm_provider="databricks",
)
except Exception as e:
## LOGGING - log the original exception returned
@@ -2068,8 +2069,8 @@ def completion(
timeout=timeout,
client=client,
)
-elif model.startswith("meta/"):
-model_response = vertex_llama_chat_completion.completion(
+elif model.startswith("meta/") or model.startswith("mistral"):
+model_response = vertex_partner_models_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,

View file

@@ -2028,6 +2028,16 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
+"vertex_ai/mistral-large@latest": {
+"max_tokens": 8191,
+"max_input_tokens": 128000,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000003,
+"output_cost_per_token": 0.000009,
+"litellm_provider": "vertex_ai-mistral_models",
+"mode": "chat",
+"supports_function_calling": true
+},
"vertex_ai/mistral-large@2407": {
"max_tokens": 8191,
"max_input_tokens": 128000,

View file

@@ -899,16 +899,18 @@ from litellm.tests.test_completion import response_format_tests
@pytest.mark.parametrize(
-"model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
+"model",
+[
+"vertex_ai/mistral-large@2407",
+"vertex_ai/meta/llama3-405b-instruct-maas",
+], #
) # "vertex_ai",
@pytest.mark.parametrize(
"sync_mode",
-[
-True,
-],
-) # False
+[True, False],
+) #
@pytest.mark.asyncio
-async def test_llama_3_httpx(model, sync_mode):
+async def test_partner_models_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
@@ -946,6 +948,57 @@ async def test_llama_3_httpx(model, sync_mode):
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
@pytest.mark.parametrize(
"model",
[
"vertex_ai/mistral-large@2407",
"vertex_ai/meta/llama3-405b-instruct-maas",
], #
) # "vertex_ai",
@pytest.mark.parametrize(
"sync_mode",
[True, False], #
) #
@pytest.mark.asyncio
async def test_partner_models_httpx_streaming(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
data = {"model": model, "messages": messages, "stream": True}
if sync_mode:
response = litellm.completion(**data)
for idx, chunk in enumerate(response):
streaming_format_tests(idx=idx, chunk=chunk)
else:
response = await litellm.acompletion(**data)
idx = 0
async for chunk in response:
streaming_format_tests(idx=idx, chunk=chunk)
idx += 1
print(f"response: {response}")
except litellm.RateLimitError:
pass
except Exception as e:
if "429 Quota exceeded" in str(e):
pass
else:
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200

View file

@@ -141,6 +141,21 @@ def test_vertex_ai_llama_3_optional_params():
assert "user" not in optional_params
def test_vertex_ai_mistral_optional_params():
litellm.vertex_mistral_models = ["mistral-large@2407"]
litellm.drop_params = True
optional_params = get_optional_params(
model="mistral-large@2407",
user="John",
custom_llm_provider="vertex_ai",
max_tokens=10,
temperature=0.2,
)
assert "user" not in optional_params
assert "max_tokens" in optional_params
assert "temperature" in optional_params
def test_azure_gpt_optional_params_gpt_vision():
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
optional_params = litellm.utils.get_optional_params(

View file

@@ -3104,6 +3104,15 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
+elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
+supported_params = get_supported_openai_params(
+model=model, custom_llm_provider=custom_llm_provider
+)
+_check_valid_arg(supported_params=supported_params)
+optional_params = litellm.MistralConfig().map_openai_params(
+non_default_params=non_default_params,
+optional_params=optional_params,
+)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@@ -4210,7 +4219,8 @@ def get_supported_openai_params(
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
+if model.startswith("mistral"):
+return litellm.MistralConfig().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
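
A sketch of the effect of this branch (the model name is one of the entries registered in litellm.vertex_mistral_models; the exact return value comes from MistralConfig):

    from litellm.utils import get_supported_openai_params

    params = get_supported_openai_params(
        model="mistral-large@2407",
        custom_llm_provider="vertex_ai",
        request_type="chat_completion",
    )
    # Returns litellm.MistralConfig().get_supported_openai_params(), which
    # includes params such as "temperature", "max_tokens", "stream" and "tools".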
@@ -9264,11 +9274,20 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
-if self.custom_llm_provider and (
-self.custom_llm_provider == "anthropic"
-or self.custom_llm_provider in litellm._custom_providers
-):
-from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import GenericStreamingChunk as GChunk
+if (
+isinstance(chunk, dict)
+and all(
+key in chunk for key in GChunk.__annotations__
+) # check if chunk is a generic streaming chunk
+) or (
+self.custom_llm_provider
+and (
+self.custom_llm_provider == "anthropic"
+or self.custom_llm_provider in litellm._custom_providers
+)
+):
if self.received_finish_reason is not None:
raise StopIteration
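
Roughly, the new guard treats any dict carrying every GenericStreamingChunk key as a provider-agnostic chunk. A sketch of the check in isolation (the chunk values are arbitrary):

    from litellm.types.utils import GenericStreamingChunk as GChunk

    chunk = GChunk(
        text="Hello",
        tool_use=None,
        is_finished=False,
        finish_reason="",
        usage=None,
        index=0,
    )
    is_generic = isinstance(chunk, dict) and all(
        key in chunk for key in GChunk.__annotations__
    )
    print(is_generic)  # True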
@@ -9634,22 +9653,6 @@ class CustomStreamWrapper:
completion_tokens=response_obj["usage"].completion_tokens,
total_tokens=response_obj["usage"].total_tokens,
)
-elif self.custom_llm_provider == "databricks":
-response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
-completion_obj["content"] = response_obj["text"]
-print_verbose(f"completion obj content: {completion_obj['content']}")
-if response_obj["is_finished"]:
-self.received_finish_reason = response_obj["finish_reason"]
-if (
-self.stream_options
-and self.stream_options.get("include_usage", False) == True
-and response_obj["usage"] is not None
-):
-model_response.usage = litellm.Usage(
-prompt_tokens=response_obj["usage"].prompt_tokens,
-completion_tokens=response_obj["usage"].completion_tokens,
-total_tokens=response_obj["usage"].total_tokens,
-)
elif self.custom_llm_provider == "azure_text":
response_obj = self.handle_azure_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]

View file

@@ -2028,6 +2028,16 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
+"vertex_ai/mistral-large@latest": {
+"max_tokens": 8191,
+"max_input_tokens": 128000,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000003,
+"output_cost_per_token": 0.000009,
+"litellm_provider": "vertex_ai-mistral_models",
+"mode": "chat",
+"supports_function_calling": true
+},
"vertex_ai/mistral-large@2407": {
"max_tokens": 8191,
"max_input_tokens": 128000,