feat(vertex_ai_partner.py): add vertex ai codestral FIM support

Closes https://github.com/BerriAI/litellm/issues/4984
This commit is contained in:
Krrish Dholakia 2024-08-01 17:10:27 -07:00
parent 246b3227a9
commit 010d5ed81d
5 changed files with 96 additions and 32 deletions

View file

@ -140,10 +140,10 @@ class VertexAIPartnerModels(BaseLLM):
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
@ -154,6 +154,7 @@ class VertexAIPartnerModels(BaseLLM):
from litellm.llms.databricks import DatabricksChatCompletion
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.text_completion_codestral import CodestralTextCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
@ -178,12 +179,7 @@ class VertexAIPartnerModels(BaseLLM):
)
openai_like_chat_completions = DatabricksChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
# for k, v in config.items():
# if k not in optional_params:
# optional_params[k] = v
codestral_fim_completions = CodestralTextCompletion()
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
@ -206,6 +202,28 @@ class VertexAIPartnerModels(BaseLLM):
model = model.split("@")[0]
if "codestral" in model and litellm_params.get("text_completion") is True:
optional_params["model"] = model
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
return codestral_fim_completions.completion(
model=model,
messages=messages,
api_base=api_base,
api_key=access_token,
custom_prompt_dict=custom_prompt_dict,
model_response=text_completion_model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
)
return openai_like_chat_completions.completion(
model=model,
messages=messages,