Merge pull request #4925 from BerriAI/litellm_vertex_mistral

feat(vertex_ai_partner.py): Vertex AI Mistral Support
Krish Dholakia committed 2024-07-27 21:51:26 -07:00 (via GitHub)
commit 1c50339580
10 changed files with 365 additions and 147 deletions
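
For context, a minimal sketch of the call shape this change enables. The
model id and the vertex_* kwargs below are illustrative assumptions, not
taken from this diff; any model listed in litellm.vertex_mistral_models
would hit the new get_optional_params branch shown below.

    import litellm

    # Hypothetical usage: a Mistral model served through Vertex AI.
    response = litellm.completion(
        model="vertex_ai/mistral-large@2407",  # illustrative model id
        messages=[{"role": "user", "content": "Hello from Vertex Mistral"}],
        temperature=0.2,  # OpenAI-style param, mapped via MistralConfig
        vertex_project="my-gcp-project",   # assumed standard Vertex kwargs
        vertex_location="us-central1",
    )
    print(response.choices[0].message.content)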

litellm/utils.py

@@ -3104,6 +3104,15 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.MistralConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
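
The new branch reuses litellm's generic Mistral parameter mapping rather
than a Vertex-specific one. A rough sketch of that flow in isolation, using
the same calls as the diff (model id and param values are illustrative):

    import litellm
    from litellm import get_supported_openai_params

    # OpenAI-style params a caller might pass to litellm.completion().
    non_default_params = {"temperature": 0.2, "max_tokens": 256}

    supported_params = get_supported_openai_params(
        model="mistral-large@2407",  # hypothetical Vertex Mistral model id
        custom_llm_provider="vertex_ai",
    )
    # Translate the OpenAI-style keys into Mistral-compatible ones.
    optional_params = litellm.MistralConfig().map_openai_params(
        non_default_params=non_default_params,
        optional_params={},
    )
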
@@ -4210,7 +4219,8 @@ def get_supported_openai_params(
         if request_type == "chat_completion":
             if model.startswith("meta/"):
                 return litellm.VertexAILlama3Config().get_supported_openai_params()
+            if model.startswith("mistral"):
+                return litellm.MistralConfig().get_supported_openai_params()
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
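
Routing here is purely prefix-based: any Vertex AI chat model whose name
starts with "mistral" now reports Mistral's supported OpenAI params. For
example (model id illustrative, returned list indicative only):

    from litellm import get_supported_openai_params

    params = get_supported_openai_params(
        model="mistral-large@2407",  # startswith("mistral") -> MistralConfig
        custom_llm_provider="vertex_ai",
        request_type="chat_completion",
    )
    # e.g. ["stream", "temperature", "top_p", "max_tokens", ...]
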
@@ -9264,11 +9274,20 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and (
-                self.custom_llm_provider == "anthropic"
-                or self.custom_llm_provider in litellm._custom_providers
+            from litellm.types.utils import GenericStreamingChunk as GChunk
+            if (
+                isinstance(chunk, dict)
+                and all(
+                    key in chunk for key in GChunk.__annotations__
+                )  # check if chunk is a generic streaming chunk
+            ) or (
+                self.custom_llm_provider
+                and (
+                    self.custom_llm_provider == "anthropic"
+                    or self.custom_llm_provider in litellm._custom_providers
+                )
             ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk

                 if self.received_finish_reason is not None:
                     raise StopIteration
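
The reworked condition duck-types the chunk: a plain dict carrying every key
declared on GenericStreamingChunk is routed to the generic streaming path,
regardless of provider. A small sketch of the check, assuming GChunk is a
TypedDict-style class whose __annotations__ enumerate its fields:

    from litellm.types.utils import GenericStreamingChunk as GChunk

    # Build a dict with every declared GChunk field (values illustrative).
    chunk = {key: None for key in GChunk.__annotations__}

    is_generic_chunk = isinstance(chunk, dict) and all(
        key in chunk for key in GChunk.__annotations__
    )
    assert is_generic_chunk  # such a chunk now takes the generic branch

This generic path is presumably what makes the provider-specific databricks
streaming branch removable in the next hunk: once the Databricks handler
emits chunks in GenericStreamingChunk form, the generic branch covers them.
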
@@ -9634,22 +9653,6 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"].completion_tokens,
                         total_tokens=response_obj["usage"].total_tokens,
                     )
-            elif self.custom_llm_provider == "databricks":
-                response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
-                completion_obj["content"] = response_obj["text"]
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
-                if (
-                    self.stream_options
-                    and self.stream_options.get("include_usage", False) == True
-                    and response_obj["usage"] is not None
-                ):
-                    model_response.usage = litellm.Usage(
-                        prompt_tokens=response_obj["usage"].prompt_tokens,
-                        completion_tokens=response_obj["usage"].completion_tokens,
-                        total_tokens=response_obj["usage"].total_tokens,
-                    )
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]