diff --git a/litellm/__init__.py b/litellm/__init__.py
index b6aacad1a..97a0a05ea 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -358,6 +358,7 @@ vertex_code_text_models: List = []
 vertex_embedding_models: List = []
 vertex_anthropic_models: List = []
 vertex_llama3_models: List = []
+vertex_mistral_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -403,6 +404,9 @@ for key, value in model_cost.items():
     elif value.get("litellm_provider") == "vertex_ai-llama_models":
         key = key.replace("vertex_ai/", "")
         vertex_llama3_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-mistral_models":
+        key = key.replace("vertex_ai/", "")
+        vertex_mistral_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -833,7 +837,7 @@ from .llms.petals import PetalsConfig
 from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
 from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
-from .llms.vertex_ai_llama import VertexAILlama3Config
+from .llms.vertex_ai_partner import VertexAILlama3Config
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 88fa58abe..363b222fe 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -15,8 +15,14 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.llms.openai import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionResponseMessage,
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
+)
+from litellm.types.utils import GenericStreamingChunk, ProviderField
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -114,71 +120,6 @@ class DatabricksConfig:
                 optional_params["stop"] = value
         return optional_params

-    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = None
-            logprobs = None
-            usage = None
-            original_chunk = None  # this is used for function/tool calling
-            chunk_data = chunk_data.replace("data:", "")
-            chunk_data = chunk_data.strip()
-            if len(chunk_data) == 0 or chunk_data == "[DONE]":
-                return {
-                    "text": "",
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            chunk_data_dict = json.loads(chunk_data)
-            str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
-
-            if len(str_line.choices) > 0:
-                if (
-                    str_line.choices[0].delta is not None  # type: ignore
-                    and str_line.choices[0].delta.content is not None  # type: ignore
-                ):
-                    text = str_line.choices[0].delta.content  # type: ignore
-                else:  # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
-                    original_chunk = str_line
-                if str_line.choices[0].finish_reason:
-                    is_finished = True
-                    finish_reason = str_line.choices[0].finish_reason
-                    if finish_reason == "content_filter":
-                        if hasattr(str_line.choices[0], "content_filter_result"):
-                            error_message = json.dumps(
-                                str_line.choices[0].content_filter_result  # type: ignore
-                            )
-                        else:
-                            error_message = "Azure Response={}".format(
-                                str(dict(str_line))
-                            )
-                        raise litellm.AzureOpenAIError(
-                            status_code=400, message=error_message
-                        )
-
-            # checking for logprobs
-            if (
-                hasattr(str_line.choices[0], "logprobs")
-                and str_line.choices[0].logprobs is not None
-            ):
-                logprobs = str_line.choices[0].logprobs
-            else:
-                logprobs = None
-
-            usage = getattr(str_line, "usage", None)
-
-            return GenericStreamingChunk(
-                text=text,
-                is_finished=is_finished,
-                finish_reason=finish_reason,
-                logprobs=logprobs,
-                original_chunk=original_chunk,
-                usage=usage,
-            )
-        except Exception as e:
-            raise e
-

 class DatabricksEmbeddingConfig:
     """
@@ -236,7 +177,9 @@ async def make_call(
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)

-    completion_stream = response.aiter_lines()
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.aiter_lines(), sync_stream=False
+    )
     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -248,6 +191,38 @@ async def make_call(
     return completion_stream


+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.read())
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.iter_lines(), sync_stream=True
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -259,6 +234,7 @@ class DatabricksChatCompletion(BaseLLM):
         api_key: Optional[str],
         api_base: Optional[str],
         endpoint_type: Literal["chat_completions", "embeddings"],
+        custom_endpoint: Optional[bool],
     ) -> Tuple[str, dict]:
         if api_key is None:
             raise DatabricksError(
@@ -277,9 +253,9 @@ class DatabricksChatCompletion(BaseLLM):
             "Content-Type": "application/json",
         }

-        if endpoint_type == "chat_completions":
+        if endpoint_type == "chat_completions" and custom_endpoint is not True:
            api_base = "{}/chat/completions".format(api_base)
-        elif endpoint_type == "embeddings":
+        elif endpoint_type == "embeddings" and custom_endpoint is not True:
            api_base = "{}/embeddings".format(api_base)
         return api_base, headers

@@ -368,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
         self,
         model: str,
         messages: list,
+        custom_llm_provider: str,
         api_base: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
@@ -397,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logging_obj=logging_obj,
             ),
             model=model,
-            custom_llm_provider="databricks",
+            custom_llm_provider=custom_llm_provider,
             logging_obj=logging_obj,
         )
         return streamwrapper
@@ -450,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
         model: str,
         messages: list,
         api_base: str,
+        custom_llm_provider: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -464,8 +442,12 @@ class DatabricksChatCompletion(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+        custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
         api_base, headers = self._validate_environment(
-            api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
+            api_base=api_base,
+            api_key=api_key,
+            endpoint_type="chat_completions",
+            custom_endpoint=custom_endpoint,
         )
         ## Load Config
         config = litellm.DatabricksConfig().get_config()
@@ -475,7 +457,8 @@ class DatabricksChatCompletion(BaseLLM):
         ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

-        stream = optional_params.pop("stream", None)
+        stream: bool = optional_params.pop("stream", None) or False
+        optional_params["stream"] = stream

         data = {
             "model": model,
@@ -518,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 client=client,
+                custom_llm_provider=custom_llm_provider,
             )
         else:
             return self.acompletion_function(
@@ -539,44 +523,29 @@ class DatabricksChatCompletion(BaseLLM):
                 timeout=timeout,
             )
         else:
-            if client is None or isinstance(client, AsyncHTTPHandler):
-                self.client = HTTPHandler(timeout=timeout)  # type: ignore
-            else:
-                self.client = client
+            if client is None or not isinstance(client, HTTPHandler):
+                client = HTTPHandler(timeout=timeout)  # type: ignore

             ## COMPLETION CALL
-            if (
-                stream is not None and stream == True
-            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
-                print_verbose("makes dbrx streaming POST request")
-                data["stream"] = stream
-                try:
-                    response = self.client.post(
-                        api_base, headers=headers, data=json.dumps(data), stream=stream
-                    )
-                    response.raise_for_status()
-                    completion_stream = response.iter_lines()
-                except httpx.HTTPStatusError as e:
-                    raise DatabricksError(
-                        status_code=e.response.status_code, message=response.text
-                    )
-                except httpx.TimeoutException as e:
-                    raise DatabricksError(
-                        status_code=408, message="Timeout error occurred."
-                    )
-                except Exception as e:
-                    raise DatabricksError(status_code=408, message=str(e))
-
-                streaming_response = CustomStreamWrapper(
-                    completion_stream=completion_stream,
+            if stream is True:
+                return CustomStreamWrapper(
+                    completion_stream=None,
+                    make_call=partial(
+                        make_sync_call,
+                        client=None,
+                        api_base=api_base,
+                        headers=headers,  # type: ignore
+                        data=json.dumps(data),
+                        model=model,
+                        messages=messages,
+                        logging_obj=logging_obj,
+                    ),
                     model=model,
-                    custom_llm_provider="databricks",
+                    custom_llm_provider=custom_llm_provider,
                     logging_obj=logging_obj,
                 )
-                return streaming_response
-            else:
                 try:
-                    response = self.client.post(
+                    response = client.post(
                         api_base, headers=headers, data=json.dumps(data)
                     )
                     response.raise_for_status()
@@ -667,7 +636,10 @@ class DatabricksChatCompletion(BaseLLM):
         aembedding=None,
     ) -> EmbeddingResponse:
         api_base, headers = self._validate_environment(
-            api_base=api_base, api_key=api_key, endpoint_type="embeddings"
+            api_base=api_base,
+            api_key=api_key,
+            endpoint_type="embeddings",
+            custom_endpoint=False,
         )
         model = model
         data = {"model": model, "input": input, **optional_params}
@@ -716,3 +688,128 @@ class DatabricksChatCompletion(BaseLLM):
         )

         return litellm.EmbeddingResponse(**response_json)
+
+
+class ModelResponseIterator:
+    def __init__(self, streaming_response, sync_stream: bool):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            processed_chunk = litellm.ModelResponse(**chunk, stream=True)  # type: ignore
+
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+
+            if processed_chunk.choices[0].delta.content is not None:  # type: ignore
+                text = processed_chunk.choices[0].delta.content  # type: ignore
+
+            if (
+                processed_chunk.choices[0].delta.tool_calls is not None  # type: ignore
+                and len(processed_chunk.choices[0].delta.tool_calls) > 0  # type: ignore
+                and processed_chunk.choices[0].delta.tool_calls[0].function is not None  # type: ignore
+                and processed_chunk.choices[0].delta.tool_calls[0].function.arguments  # type: ignore
+                is not None
+            ):
+                tool_use = ChatCompletionToolCallChunk(
+                    id=processed_chunk.choices[0].delta.tool_calls[0].id,  # type: ignore
+                    type="function",
+                    function=ChatCompletionToolCallFunctionChunk(
+                        name=processed_chunk.choices[0]
+                        .delta.tool_calls[0]  # type: ignore
+                        .function.name,
+                        arguments=processed_chunk.choices[0]
+                        .delta.tool_calls[0]  # type: ignore
+                        .function.arguments,
+                    ),
+                    index=processed_chunk.choices[0].index,
+                )
+
+            if processed_chunk.choices[0].finish_reason is not None:
+                is_finished = True
+                finish_reason = processed_chunk.choices[0].finish_reason
+
+            if hasattr(processed_chunk, "usage"):
+                usage = processed_chunk.usage  # type: ignore
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=0,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        self.response_iterator = self.streaming_response
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
+                return self.chunk_parser(chunk=json_chunk)
+            else:
+                return GenericStreamingChunk(
+                    text="",
+                    is_finished=False,
+                    finish_reason="",
+                    usage=None,
+                    index=0,
+                    tool_use=None,
+                )
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            if chunk == "[DONE]":
+                raise StopAsyncIteration
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
+                return self.chunk_parser(chunk=json_chunk)
+            else:
+                return GenericStreamingChunk(
+                    text="",
+                    is_finished=False,
+                    finish_reason="",
+                    usage=None,
+                    index=0,
+                    tool_use=None,
+                )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
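
Illustration (not part of the patch): a minimal sketch of how the ModelResponseIterator added above is expected to behave when fed OpenAI-style SSE lines. The two chunk payloads below are fabricated for illustration, and the example assumes the class is importable from litellm.llms.databricks as in this diff.

# Illustration only -- not part of the patch. Feeds two fabricated SSE lines
# (OpenAI chat-completion chunk format) through the new sync iterator.
from litellm.llms.databricks import ModelResponseIterator

fake_sse_lines = iter(
    [
        'data: {"id": "chatcmpl-1", "object": "chat.completion.chunk", "created": 1721700000, "model": "databricks-dbrx-instruct", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}',
        'data: {"id": "chatcmpl-1", "object": "chat.completion.chunk", "created": 1721700000, "model": "databricks-dbrx-instruct", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}',
    ]
)

for parsed in ModelResponseIterator(streaming_response=fake_sse_lines, sync_stream=True):
    # Each item is a GenericStreamingChunk TypedDict (text, tool_use, is_finished,
    # finish_reason, usage, index), which CustomStreamWrapper then consumes.
    print(parsed["text"], parsed["is_finished"], parsed["finish_reason"])
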
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 94000233c..afd49ab14 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -160,7 +160,7 @@ class MistralConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
-            if param == "stream" and value == True:
+            if param == "stream" and value is True:
                 optional_params["stream"] = value
             if param == "temperature":
                 optional_params["temperature"] = value
diff --git a/litellm/llms/vertex_ai_llama.py b/litellm/llms/vertex_ai_partner.py
similarity index 78%
rename from litellm/llms/vertex_ai_llama.py
rename to litellm/llms/vertex_ai_partner.py
index cc4786c4b..eb24c4d26 100644
--- a/litellm/llms/vertex_ai_llama.py
+++ b/litellm/llms/vertex_ai_partner.py
@@ -7,7 +7,7 @@ import time
 import types
 import uuid
 from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -108,14 +108,25 @@ class VertexAILlama3Config:
         return optional_params


-class VertexAILlama3(BaseLLM):
+class VertexAIPartnerModels(BaseLLM):
     def __init__(self) -> None:
         pass

-    def create_vertex_llama3_url(
-        self, vertex_location: str, vertex_project: str
+    def create_vertex_url(
+        self,
+        vertex_location: str,
+        vertex_project: str,
+        partner: Literal["llama", "mistralai"],
+        stream: Optional[bool],
+        model: str,
     ) -> str:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+        if partner == "llama":
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+        elif partner == "mistralai":
+            if stream:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+            else:
+                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"

     def completion(
         self,
@@ -141,6 +152,7 @@
             import vertexai
             from google.cloud import aiplatform

+            from litellm.llms.databricks import DatabricksChatCompletion
             from litellm.llms.openai import OpenAIChatCompletion
             from litellm.llms.vertex_httpx import VertexLLM
         except Exception:
@@ -165,7 +177,7 @@
                 credentials=vertex_credentials, project_id=vertex_project
             )

-            openai_chat_completions = OpenAIChatCompletion()
+            openai_like_chat_completions = DatabricksChatCompletion()

             ## Load Config
             # config = litellm.VertexAILlama3.get_config()
@@ -178,12 +190,23 @@

             optional_params["stream"] = stream

-            api_base = self.create_vertex_llama3_url(
+            if "llama" in model:
+                partner = "llama"
+            elif "mistral" in model:
+                partner = "mistralai"
+                optional_params["custom_endpoint"] = True
+
+            api_base = self.create_vertex_url(
                 vertex_location=vertex_location or "us-central1",
                 vertex_project=vertex_project or project_id,
+                partner=partner,  # type: ignore
+                stream=stream,
+                model=model,
             )

-            return openai_chat_completions.completion(
+            model = model.split("@")[0]
+
+            return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -198,6 +221,8 @@
                 logger_fn=logger_fn,
                 client=client,
                 timeout=timeout,
+                encoding=encoding,
+                custom_llm_provider="vertex_ai_beta",
             )

         except Exception as e:
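
Illustration (not part of the patch): given the f-strings in create_vertex_url above, the partner routing resolves to the endpoints shown below. The project and location values are placeholders.

# Illustration only -- expected create_vertex_url output, based on the hunk above.
from litellm.llms.vertex_ai_partner import VertexAIPartnerModels

builder = VertexAIPartnerModels()

# Llama partner models use the OpenAI-compatible "openapi" endpoint:
print(builder.create_vertex_url(
    vertex_location="us-central1", vertex_project="my-project",  # placeholders
    partner="llama", stream=False, model="meta/llama3-405b-instruct-maas",
))
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/endpoints/openapi

# Mistral models use rawPredict, or streamRawPredict when stream=True:
print(builder.create_vertex_url(
    vertex_location="us-central1", vertex_project="my-project",  # placeholders
    partner="mistralai", stream=True, model="mistral-large@2407",
))
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/mistralai/models/mistral-large@2407:streamRawPredict
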
{ "max_tokens": 8191, "max_input_tokens": 128000, diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index bebe5d031..5419c25ff 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -899,16 +899,18 @@ from litellm.tests.test_completion import response_format_tests @pytest.mark.parametrize( - "model", ["vertex_ai/meta/llama3-405b-instruct-maas"] + "model", + [ + "vertex_ai/mistral-large@2407", + "vertex_ai/meta/llama3-405b-instruct-maas", + ], # ) # "vertex_ai", @pytest.mark.parametrize( "sync_mode", - [ - True, - ], -) # False + [True, False], +) # @pytest.mark.asyncio -async def test_llama_3_httpx(model, sync_mode): +async def test_partner_models_httpx(model, sync_mode): try: load_vertex_ai_credentials() litellm.set_verbose = True @@ -946,6 +948,57 @@ async def test_llama_3_httpx(model, sync_mode): pytest.fail("An unexpected exception occurred - {}".format(str(e))) +@pytest.mark.parametrize( + "model", + [ + "vertex_ai/mistral-large@2407", + "vertex_ai/meta/llama3-405b-instruct-maas", + ], # +) # "vertex_ai", +@pytest.mark.parametrize( + "sync_mode", + [True, False], # +) # +@pytest.mark.asyncio +async def test_partner_models_httpx_streaming(model, sync_mode): + try: + load_vertex_ai_credentials() + litellm.set_verbose = True + + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + # User asks for their name and weather in San Francisco + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + data = {"model": model, "messages": messages, "stream": True} + if sync_mode: + response = litellm.completion(**data) + for idx, chunk in enumerate(response): + streaming_format_tests(idx=idx, chunk=chunk) + else: + response = await litellm.acompletion(**data) + idx = 0 + async for chunk in response: + streaming_format_tests(idx=idx, chunk=chunk) + idx += 1 + + print(f"response: {response}") + except litellm.RateLimitError: + pass + except Exception as e: + if "429 Quota exceeded" in str(e): + pass + else: + pytest.fail("An unexpected exception occurred - {}".format(str(e))) + + def vertex_httpx_mock_reject_prompt_post(*args, **kwargs): mock_response = MagicMock() mock_response.status_code = 200 diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index b8011960e..83ac855a8 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -141,6 +141,21 @@ def test_vertex_ai_llama_3_optional_params(): assert "user" not in optional_params +def test_vertex_ai_mistral_optional_params(): + litellm.vertex_mistral_models = ["mistral-large@2407"] + litellm.drop_params = True + optional_params = get_optional_params( + model="mistral-large@2407", + user="John", + custom_llm_provider="vertex_ai", + max_tokens=10, + temperature=0.2, + ) + assert "user" not in optional_params + assert "max_tokens" in optional_params + assert "temperature" in optional_params + + def test_azure_gpt_optional_params_gpt_vision(): # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. 
We assert we actually set extra_body here optional_params = litellm.utils.get_optional_params( diff --git a/litellm/utils.py b/litellm/utils.py index 4e3a4e60a..ec1370c30 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3104,6 +3104,15 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, ) + elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models: + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.MistralConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) elif custom_llm_provider == "sagemaker": ## check if unsupported param passed in supported_params = get_supported_openai_params( @@ -4210,7 +4219,8 @@ def get_supported_openai_params( if request_type == "chat_completion": if model.startswith("meta/"): return litellm.VertexAILlama3Config().get_supported_openai_params() - + if model.startswith("mistral"): + return litellm.MistralConfig().get_supported_openai_params() return litellm.VertexAIConfig().get_supported_openai_params() elif request_type == "embeddings": return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params() @@ -9264,11 +9274,20 @@ class CustomStreamWrapper: try: # return this for all models completion_obj = {"content": ""} - if self.custom_llm_provider and ( - self.custom_llm_provider == "anthropic" - or self.custom_llm_provider in litellm._custom_providers + from litellm.types.utils import GenericStreamingChunk as GChunk + + if ( + isinstance(chunk, dict) + and all( + key in chunk for key in GChunk.__annotations__ + ) # check if chunk is a generic streaming chunk + ) or ( + self.custom_llm_provider + and ( + self.custom_llm_provider == "anthropic" + or self.custom_llm_provider in litellm._custom_providers + ) ): - from litellm.types.utils import GenericStreamingChunk as GChunk if self.received_finish_reason is not None: raise StopIteration @@ -9634,22 +9653,6 @@ class CustomStreamWrapper: completion_tokens=response_obj["usage"].completion_tokens, total_tokens=response_obj["usage"].total_tokens, ) - elif self.custom_llm_provider == "databricks": - response_obj = litellm.DatabricksConfig()._chunk_parser(chunk) - completion_obj["content"] = response_obj["text"] - print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] - if ( - self.stream_options - and self.stream_options.get("include_usage", False) == True - and response_obj["usage"] is not None - ): - model_response.usage = litellm.Usage( - prompt_tokens=response_obj["usage"].prompt_tokens, - completion_tokens=response_obj["usage"].completion_tokens, - total_tokens=response_obj["usage"].total_tokens, - ) elif self.custom_llm_provider == "azure_text": response_obj = self.handle_azure_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index a6d66750c..d20e5681c 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -2028,6 +2028,16 @@ "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models" }, + "vertex_ai/mistral-large@latest": { + "max_tokens": 8191, + "max_input_tokens": 128000, + "max_output_tokens": 8191, + "input_cost_per_token": 
0.000003, + "output_cost_per_token": 0.000009, + "litellm_provider": "vertex_ai-mistral_models", + "mode": "chat", + "supports_function_calling": true + }, "vertex_ai/mistral-large@2407": { "max_tokens": 8191, "max_input_tokens": 128000,
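
Illustration (not part of the patch): end-to-end usage of the new Vertex AI Mistral route, mirroring test_partner_models_httpx_streaming above. The GCP project and location are placeholders, and valid Vertex AI credentials are assumed to be configured.

# Illustration only -- streaming call against the new vertex_ai Mistral route.
import litellm

response = litellm.completion(
    model="vertex_ai/mistral-large@2407",
    messages=[{"role": "user", "content": "Hello, what is your name?"}],
    vertex_project="my-project",       # placeholder
    vertex_location="us-central1",     # placeholder
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
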