feat(vertex_ai_partner.py): add vertex ai codestral FIM support

Closes https://github.com/BerriAI/litellm/issues/4984
Krrish Dholakia 2024-08-01 17:10:27 -07:00
parent 2121738137
commit cb9b19e887
5 changed files with 96 additions and 32 deletions
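For orientation, here is a minimal usage sketch of what this commit enables: fill-in-the-middle (FIM) text completion against Codestral served on Vertex AI. The model name, prompt, and suffix are copied from the tests changed below; Vertex AI credentials and project/location configuration are assumed to already be set up in the environment.

# Minimal sketch, not part of the commit: exercise the new Vertex AI Codestral FIM route.
# Assumes Vertex AI auth (e.g. GOOGLE_APPLICATION_CREDENTIALS) is already configured.
import asyncio

import litellm


async def main():
    response = await litellm.atext_completion(
        model="vertex_ai/codestral@2405",  # new route added by this commit
        prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",  # FIM: the text that should follow the completion
        temperature=0,
    )
    print(response.choices[0].text)


asyncio.run(main())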


@@ -1,28 +1,33 @@
 # What is this?
 ## Controller file for TextCompletionCodestral Integration - https://codestral.com/
-from functools import partial
-import os, types
-import traceback
-import json
-from enum import Enum
-import requests, copy  # type: ignore
-import time
-from typing import Callable, Optional, List, Literal, Union
-from litellm.utils import (
-    TextCompletionResponse,
-    Usage,
-    CustomStreamWrapper,
-    Message,
-    Choices,
-)
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.types.llms.databricks import GenericStreamingChunk
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-from .base import BaseLLM
-import httpx  # type: ignore
+import copy
+import json
+import os
+import time
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    TextCompletionResponse,
+    Usage,
+)
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class TextCompletionCodestralError(Exception):
@@ -329,7 +334,12 @@ class CodestralTextCompletion(BaseLLM):
     ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
         headers = self._validate_environment(api_key, headers)
 
-        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
+        if optional_params.pop("custom_endpoint", None) is True:
+            completion_url = api_base
+        else:
+            completion_url = (
+                api_base or "https://codestral.mistral.ai/v1/fim/completions"
+            )
 
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
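A small illustrative sketch of the endpoint selection above: when the caller (for example the Vertex AI partner route) sets custom_endpoint=True in optional_params, the handler uses api_base verbatim instead of falling back to the public Mistral FIM endpoint. The helper name and the Vertex URL below are hypothetical; only the branching logic is taken from the hunk.

# Hypothetical helper mirroring the endpoint selection shown in the hunk above.
def pick_completion_url(api_base, optional_params):
    if optional_params.pop("custom_endpoint", None) is True:
        return api_base  # trust the caller-supplied URL as-is (e.g. a Vertex AI endpoint)
    return api_base or "https://codestral.mistral.ai/v1/fim/completions"


print(pick_completion_url(None, {}))  # public Mistral FIM endpoint
print(pick_completion_url("https://example-vertex-endpoint/rawPredict", {"custom_endpoint": True}))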
@@ -426,6 +436,7 @@ class CodestralTextCompletion(BaseLLM):
             return _response
         ### SYNC COMPLETION
         else:
             response = requests.post(
                 url=completion_url,
                 headers=headers,
@@ -464,8 +475,11 @@ class CodestralTextCompletion(BaseLLM):
         headers={},
     ) -> TextCompletionResponse:
 
-        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+        )
+
         try:
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )


@@ -140,10 +140,10 @@ class VertexAIPartnerModels(BaseLLM):
         custom_prompt_dict: dict,
         headers: Optional[dict],
         timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         vertex_project=None,
         vertex_location=None,
         vertex_credentials=None,
-        litellm_params=None,
         logger_fn=None,
         acompletion: bool = False,
         client=None,
@@ -154,6 +154,7 @@ class VertexAIPartnerModels(BaseLLM):
             from litellm.llms.databricks import DatabricksChatCompletion
             from litellm.llms.openai import OpenAIChatCompletion
+            from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_httpx import VertexLLM
 
         except Exception:
@@ -178,12 +179,7 @@ class VertexAIPartnerModels(BaseLLM):
             )
 
             openai_like_chat_completions = DatabricksChatCompletion()
-
-            ## Load Config
-            # config = litellm.VertexAILlama3.get_config()
-            # for k, v in config.items():
-            #     if k not in optional_params:
-            #         optional_params[k] = v
+            codestral_fim_completions = CodestralTextCompletion()
 
             ## CONSTRUCT API BASE
             stream: bool = optional_params.get("stream", False) or False
@@ -206,6 +202,28 @@ class VertexAIPartnerModels(BaseLLM):
                 model = model.split("@")[0]
 
+            if "codestral" in model and litellm_params.get("text_completion") is True:
+                optional_params["model"] = model
+                text_completion_model_response = litellm.TextCompletionResponse(
+                    stream=stream
+                )
+                return codestral_fim_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    api_key=access_token,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=text_completion_model_response,
+                    print_verbose=print_verbose,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    acompletion=acompletion,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,
+                    encoding=encoding,
+                )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
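To make the new branch easier to follow, here is a rough sketch of the routing decision added above: Codestral models on Vertex AI are sent to the FIM handler only when the request carries the text_completion flag in litellm_params; everything else keeps using the OpenAI-like chat path. The values below are illustrative, and the assumption is that the flag is set when the request comes through the text-completion entrypoint.

# Sketch of the routing condition from the hunk above (illustrative values only).
litellm_params = {"text_completion": True}  # assumed set for atext_completion()/text_completion() requests
model = "codestral@2405".split("@")[0]      # version suffix is stripped first -> "codestral"

if "codestral" in model and litellm_params.get("text_completion") is True:
    handler = "CodestralTextCompletion"   # new FIM path
else:
    handler = "DatabricksChatCompletion"  # existing OpenAI-like chat path

print(handler)  # CodestralTextCompletion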


@@ -986,6 +986,7 @@ def completion(
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
         cooldown_time=cooldown_time,
+        text_completion=kwargs.get("text_completion"),
     )
     logging.update_environment_variables(
         model=model,


@@ -4104,9 +4104,19 @@ async def test_async_text_completion_chat_model_stream():
 # asyncio.run(test_async_text_completion_chat_model_stream())
 
 
+@pytest.mark.parametrize(
+    "model", ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"]  #
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api():
+async def test_completion_codestral_fim_api(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
         litellm.set_verbose = True
         import logging
@@ -4114,7 +4124,7 @@ async def test_completion_codestral_fim_api():
         verbose_logger.setLevel(level=logging.DEBUG)
 
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,
@@ -4137,9 +4147,19 @@ async def test_completion_codestral_fim_api():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize(
+    "model",
+    ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"],
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api_stream():
+async def test_completion_codestral_fim_api_stream(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
+
         import logging
 
         from litellm._logging import verbose_logger
@@ -4148,7 +4168,7 @@ async def test_completion_codestral_fim_api_stream():
         # verbose_logger.setLevel(level=logging.DEBUG)
 
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,


@@ -2258,6 +2258,7 @@ def get_litellm_params(
     output_cost_per_token=None,
     output_cost_per_second=None,
     cooldown_time=None,
+    text_completion=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2281,6 +2282,7 @@ def get_litellm_params(
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
+        "text_completion": text_completion,
     }
     return litellm_params
@@ -3127,10 +3129,15 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.MistralConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-        )
+        if "codestral" in model:
+            optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+                non_default_params=non_default_params, optional_params=optional_params
+            )
+        else:
+            optional_params = litellm.MistralConfig().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
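For codestral models this means OpenAI-style parameters are now translated by the Mistral text-completion (FIM) config rather than the chat config. A hedged sketch of that call, using the argument names shown in the diff; the exact set of parameters that get mapped depends on litellm.MistralTextCompletionConfig and the example values are arbitrary.

# Sketch: map OpenAI-style params for a codestral FIM request (call shape taken from the hunk above).
import litellm

non_default_params = {"temperature": 0, "max_tokens": 17}  # example values only
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
    non_default_params=non_default_params, optional_params={}
)
print(optional_params)  # e.g. {"temperature": 0, "max_tokens": 17}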
@@ -4239,6 +4246,10 @@ def get_supported_openai_params(
             return litellm.VertexAILlama3Config().get_supported_openai_params()
         if model.startswith("mistral"):
             return litellm.MistralConfig().get_supported_openai_params()
+        if model.startswith("codestral"):
+            return (
+                litellm.MistralTextCompletionConfig().get_supported_openai_params()
+            )
         return litellm.VertexAIConfig().get_supported_openai_params()
     elif request_type == "embeddings":
         return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()