Merge pull request #4925 from BerriAI/litellm_vertex_mistral

feat(vertex_ai_partner.py): Vertex AI Mistral Support
Krish Dholakia 2024-07-27 21:51:26 -07:00 committed by GitHub
commit e3a94ac013
10 changed files with 365 additions and 147 deletions
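
For orientation, a minimal usage sketch of the feature this PR adds, mirroring the new tests further down (Vertex AI credentials and project setup are assumed and not shown):

import litellm

# Assumes Vertex AI credentials are already configured (e.g. via
# GOOGLE_APPLICATION_CREDENTIALS) and the Mistral partner model is enabled
# for the project.
response = litellm.completion(
    model="vertex_ai/mistral-large@2407",
    messages=[{"role": "user", "content": "Hello, what is your name?"}],
)
print(response.choices[0].message.content)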

View file

@@ -358,6 +358,7 @@ vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
vertex_llama3_models: List = []
vertex_mistral_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@@ -403,6 +404,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-llama_models":
key = key.replace("vertex_ai/", "")
vertex_llama3_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-mistral_models":
key = key.replace("vertex_ai/", "")
vertex_mistral_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":
@@ -833,7 +837,7 @@ from .llms.petals import PetalsConfig
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.vertex_ai_llama import VertexAILlama3Config
from .llms.vertex_ai_partner import VertexAILlama3Config
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
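
A brief sketch of what the new registration branch above yields at import time (model names are taken from the pricing entries added later in this diff; exact contents depend on the local pricing file):

import litellm

# The "vertex_ai/" prefix is stripped before the key is appended, so the list
# should look roughly like ["mistral-large@latest", "mistral-large@2407"].
print(litellm.vertex_mistral_models)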

View file

@@ -15,8 +15,14 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import ProviderField
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
@@ -114,71 +120,6 @@ class DatabricksConfig:
optional_params["stop"] = value
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
try:
text = ""
is_finished = False
finish_reason = None
logprobs = None
usage = None
original_chunk = None # this is used for function/tool calling
chunk_data = chunk_data.replace("data:", "")
chunk_data = chunk_data.strip()
if len(chunk_data) == 0 or chunk_data == "[DONE]":
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
chunk_data_dict = json.loads(chunk_data)
str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
if len(str_line.choices) > 0:
if (
str_line.choices[0].delta is not None # type: ignore
and str_line.choices[0].delta.content is not None # type: ignore
):
text = str_line.choices[0].delta.content # type: ignore
else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
original_chunk = str_line
if str_line.choices[0].finish_reason:
is_finished = True
finish_reason = str_line.choices[0].finish_reason
if finish_reason == "content_filter":
if hasattr(str_line.choices[0], "content_filter_result"):
error_message = json.dumps(
str_line.choices[0].content_filter_result # type: ignore
)
else:
error_message = "Azure Response={}".format(
str(dict(str_line))
)
raise litellm.AzureOpenAIError(
status_code=400, message=error_message
)
# checking for logprobs
if (
hasattr(str_line.choices[0], "logprobs")
and str_line.choices[0].logprobs is not None
):
logprobs = str_line.choices[0].logprobs
else:
logprobs = None
usage = getattr(str_line, "usage", None)
return GenericStreamingChunk(
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
logprobs=logprobs,
original_chunk=original_chunk,
usage=usage,
)
except Exception as e:
raise e
class DatabricksEmbeddingConfig:
"""
@@ -236,7 +177,9 @@ async def make_call(
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.text)
completion_stream = response.aiter_lines()
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
@@ -248,6 +191,38 @@ async def make_call(
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
class DatabricksChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -259,6 +234,7 @@ class DatabricksChatCompletion(BaseLLM):
api_key: Optional[str],
api_base: Optional[str],
endpoint_type: Literal["chat_completions", "embeddings"],
custom_endpoint: Optional[bool],
) -> Tuple[str, dict]:
if api_key is None:
raise DatabricksError(
@@ -277,9 +253,9 @@ class DatabricksChatCompletion(BaseLLM):
"Content-Type": "application/json",
}
if endpoint_type == "chat_completions":
if endpoint_type == "chat_completions" and custom_endpoint is not True:
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings":
elif endpoint_type == "embeddings" and custom_endpoint is not True:
api_base = "{}/embeddings".format(api_base)
return api_base, headers
@@ -368,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
self,
model: str,
messages: list,
custom_llm_provider: str,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
@@ -397,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="databricks",
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streamwrapper
@@ -450,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
@@ -464,8 +442,12 @@ class DatabricksChatCompletion(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
api_base, headers = self._validate_environment(
api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
api_base=api_base,
api_key=api_key,
endpoint_type="chat_completions",
custom_endpoint=custom_endpoint,
)
## Load Config
config = litellm.DatabricksConfig().get_config()
@@ -475,7 +457,8 @@ class DatabricksChatCompletion(BaseLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", None)
stream: bool = optional_params.pop("stream", None) or False
optional_params["stream"] = stream
data = {
"model": model,
@@ -518,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
logger_fn=logger_fn,
headers=headers,
client=client,
custom_llm_provider=custom_llm_provider,
)
else:
return self.acompletion_function(
@@ -539,44 +523,29 @@ class DatabricksChatCompletion(BaseLLM):
timeout=timeout,
)
else:
if client is None or isinstance(client, AsyncHTTPHandler):
self.client = HTTPHandler(timeout=timeout) # type: ignore
else:
self.client = client
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(timeout=timeout) # type: ignore
## COMPLETION CALL
if (
stream is not None and stream == True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes dbrx streaming POST request")
data["stream"] = stream
try:
response = self.client.post(
api_base, headers=headers, data=json.dumps(data), stream=stream
)
response.raise_for_status()
completion_stream = response.iter_lines()
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code, message=response.text
)
except httpx.TimeoutException as e:
raise DatabricksError(
status_code=408, message="Timeout error occurred."
)
except Exception as e:
raise DatabricksError(status_code=408, message=str(e))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
if stream is True:
return CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="databricks",
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streaming_response
else:
try:
response = self.client.post(
response = client.post(
api_base, headers=headers, data=json.dumps(data)
)
response.raise_for_status()
@@ -667,7 +636,10 @@ class DatabricksChatCompletion(BaseLLM):
aembedding=None,
) -> EmbeddingResponse:
api_base, headers = self._validate_environment(
api_base=api_base, api_key=api_key, endpoint_type="embeddings"
api_base=api_base,
api_key=api_key,
endpoint_type="embeddings",
custom_endpoint=False,
)
model = model
data = {"model": model, "input": input, **optional_params}
@@ -716,3 +688,128 @@ class DatabricksChatCompletion(BaseLLM):
)
return litellm.EmbeddingResponse(**response_json)
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage"):
usage = processed_chunk.usage # type: ignore
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
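
A short sketch (not part of the diff) of how the new iterator behaves when driven directly; the SSE-style payloads are illustrative, and it assumes ModelResponseIterator is importable from litellm.llms.databricks, where this diff defines it:

import json

from litellm.llms.databricks import ModelResponseIterator

# Two illustrative OpenAI-style streaming chunks, as they would arrive on the wire.
sample_lines = iter([
    "data: " + json.dumps({
        "id": "chunk-1", "object": "chat.completion.chunk", "created": 0,
        "model": "mistral-large",
        "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
    }),
    "data: " + json.dumps({
        "id": "chunk-2", "object": "chat.completion.chunk", "created": 0,
        "model": "mistral-large",
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }),
])

iterator = ModelResponseIterator(streaming_response=sample_lines, sync_stream=True)
for parsed in iterator:
    # Each item is a GenericStreamingChunk dict: text, tool_use, is_finished, ...
    print(parsed["text"], parsed["is_finished"], parsed["finish_reason"])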

View file

@@ -160,7 +160,7 @@ class MistralConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value == True:
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value

View file

@@ -7,7 +7,7 @@ import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -108,14 +108,25 @@ class VertexAILlama3Config:
return optional_params
class VertexAILlama3(BaseLLM):
class VertexAIPartnerModels(BaseLLM):
def __init__(self) -> None:
pass
def create_vertex_llama3_url(
self, vertex_location: str, vertex_project: str
def create_vertex_url(
self,
vertex_location: str,
vertex_project: str,
partner: Literal["llama", "mistralai"],
stream: Optional[bool],
model: str,
) -> str:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
if partner == "llama":
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == "mistralai":
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
def completion(
self,
@@ -141,6 +152,7 @@ class VertexAILlama3(BaseLLM):
import vertexai
from google.cloud import aiplatform
from litellm.llms.databricks import DatabricksChatCompletion
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
@@ -165,7 +177,7 @@ class VertexAILlama3(BaseLLM):
credentials=vertex_credentials, project_id=vertex_project
)
openai_chat_completions = OpenAIChatCompletion()
openai_like_chat_completions = DatabricksChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
@@ -178,12 +190,23 @@ class VertexAILlama3(BaseLLM):
optional_params["stream"] = stream
api_base = self.create_vertex_llama3_url(
if "llama" in model:
partner = "llama"
elif "mistral" in model:
partner = "mistralai"
optional_params["custom_endpoint"] = True
api_base = self.create_vertex_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
partner=partner, # type: ignore
stream=stream,
model=model,
)
return openai_chat_completions.completion(
model = model.split("@")[0]
return openai_like_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -198,6 +221,8 @@ class VertexAILlama3(BaseLLM):
logger_fn=logger_fn,
client=client,
timeout=timeout,
encoding=encoding,
custom_llm_provider="vertex_ai_beta",
)
except Exception as e:
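
A hedged sketch of the endpoint URL the new create_vertex_url helper builds for a Mistral partner model (the project id below is hypothetical):

from litellm.llms.vertex_ai_partner import VertexAIPartnerModels

url = VertexAIPartnerModels().create_vertex_url(
    vertex_location="us-central1",
    vertex_project="my-gcp-project",  # hypothetical project id
    partner="mistralai",
    stream=True,
    model="mistral-large@2407",
)
# Expected:
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/publishers/mistralai/models/mistral-large@2407:streamRawPredict
print(url)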

View file

@@ -121,7 +121,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_llama import VertexAILlama3
from .llms.vertex_ai_partner import VertexAIPartnerModels
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
@@ -158,7 +158,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_llama_chat_completion = VertexAILlama3()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -1867,6 +1867,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
custom_llm_provider="databricks",
)
except Exception as e:
## LOGGING - log the original exception returned
@@ -2068,8 +2069,8 @@ def completion(
timeout=timeout,
client=client,
)
elif model.startswith("meta/"):
model_response = vertex_llama_chat_completion.completion(
elif model.startswith("meta/") or model.startswith("mistral"):
model_response = vertex_partner_models_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,

View file

@@ -2028,6 +2028,16 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/mistral-large@latest": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "vertex_ai-mistral_models",
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/mistral-large@2407": {
"max_tokens": 8191,
"max_input_tokens": 128000,

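Given the per-token prices in the new entry, a rough cost sketch (token counts are illustrative):

# 1,000 prompt tokens and 500 completion tokens at the rates above:
prompt_cost = 1_000 * 0.000003        # $0.0030
completion_cost = 500 * 0.000009      # $0.0045
print(prompt_cost + completion_cost)  # 0.0075

# litellm.cost_per_token(model="vertex_ai/mistral-large@2407",
#                        prompt_tokens=1000, completion_tokens=500)
# should derive the same numbers from this pricing table.
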
View file

@@ -899,16 +899,18 @@ from litellm.tests.test_completion import response_format_tests
@pytest.mark.parametrize(
"model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
"model",
[
"vertex_ai/mistral-large@2407",
"vertex_ai/meta/llama3-405b-instruct-maas",
], #
) # "vertex_ai",
@pytest.mark.parametrize(
"sync_mode",
[
True,
],
) # False
[True, False],
) #
@pytest.mark.asyncio
async def test_llama_3_httpx(model, sync_mode):
async def test_partner_models_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
@@ -946,6 +948,57 @@ async def test_llama_3_httpx(model, sync_mode):
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
@pytest.mark.parametrize(
"model",
[
"vertex_ai/mistral-large@2407",
"vertex_ai/meta/llama3-405b-instruct-maas",
], #
) # "vertex_ai",
@pytest.mark.parametrize(
"sync_mode",
[True, False], #
) #
@pytest.mark.asyncio
async def test_partner_models_httpx_streaming(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
data = {"model": model, "messages": messages, "stream": True}
if sync_mode:
response = litellm.completion(**data)
for idx, chunk in enumerate(response):
streaming_format_tests(idx=idx, chunk=chunk)
else:
response = await litellm.acompletion(**data)
idx = 0
async for chunk in response:
streaming_format_tests(idx=idx, chunk=chunk)
idx += 1
print(f"response: {response}")
except litellm.RateLimitError:
pass
except Exception as e:
if "429 Quota exceeded" in str(e):
pass
else:
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200

View file

@@ -141,6 +141,21 @@ def test_vertex_ai_llama_3_optional_params():
assert "user" not in optional_params
def test_vertex_ai_mistral_optional_params():
litellm.vertex_mistral_models = ["mistral-large@2407"]
litellm.drop_params = True
optional_params = get_optional_params(
model="mistral-large@2407",
user="John",
custom_llm_provider="vertex_ai",
max_tokens=10,
temperature=0.2,
)
assert "user" not in optional_params
assert "max_tokens" in optional_params
assert "temperature" in optional_params
def test_azure_gpt_optional_params_gpt_vision():
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
optional_params = litellm.utils.get_optional_params(

View file

@@ -3104,6 +3104,15 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@@ -4210,7 +4219,8 @@ def get_supported_openai_params(
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"):
return litellm.MistralConfig().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
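
A small sketch of the effect of the new model.startswith("mistral") branch above (the exact parameter list comes from MistralConfig and may differ):

from litellm.utils import get_supported_openai_params

params = get_supported_openai_params(
    model="mistral-large@2407", custom_llm_provider="vertex_ai"
)
# Should now include Mistral-supported params such as "temperature",
# "max_tokens", "stream", and "tools".
print(params)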
@@ -9264,11 +9274,20 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
from litellm.types.utils import GenericStreamingChunk as GChunk
if (
isinstance(chunk, dict)
and all(
key in chunk for key in GChunk.__annotations__
) # check if chunk is a generic streaming chunk
) or (
self.custom_llm_provider
and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
)
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
raise StopIteration
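
For clarity, a tiny sketch of the key-based check introduced above: a chunk produced by the new ModelResponseIterator carries every GenericStreamingChunk field, so it takes this branch regardless of provider:

from litellm.types.utils import GenericStreamingChunk as GChunk

chunk = GChunk(
    text="Hello", tool_use=None, is_finished=False,
    finish_reason="", usage=None, index=0,
)
# True when every annotated key of GenericStreamingChunk is present in the dict.
print(all(key in chunk for key in GChunk.__annotations__))
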
@@ -9634,22 +9653,6 @@ class CustomStreamWrapper:
completion_tokens=response_obj["usage"].completion_tokens,
total_tokens=response_obj["usage"].total_tokens,
)
elif self.custom_llm_provider == "databricks":
response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) == True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].prompt_tokens,
completion_tokens=response_obj["usage"].completion_tokens,
total_tokens=response_obj["usage"].total_tokens,
)
elif self.custom_llm_provider == "azure_text":
response_obj = self.handle_azure_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]

View file

@@ -2028,6 +2028,16 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/mistral-large@latest": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "vertex_ai-mistral_models",
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/mistral-large@2407": {
"max_tokens": 8191,
"max_input_tokens": 128000,