feat(vertex_ai_partner.py): add vertex ai codestral FIM support

Closes https://github.com/BerriAI/litellm/issues/4984
Krrish Dholakia 2024-08-01 17:10:27 -07:00
parent 2121738137
commit cb9b19e887
5 changed files with 96 additions and 32 deletions
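
For context, the new Vertex AI Codestral FIM path is exercised the same way as the existing `text-completion-codestral` provider, just with the `vertex_ai/` prefix. A minimal sketch, assuming Vertex AI credentials and project are already configured in the environment (the model name and parameters mirror the tests added below):

import asyncio

import litellm


async def main() -> None:
    # Fill-in-the-middle (FIM) request: `prompt` is the code before the cursor,
    # `suffix` is the code that should follow the generated infill.
    response = await litellm.atext_completion(
        model="vertex_ai/codestral@2405",
        prompt="def is_odd(n): \n    return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
        temperature=0,
    )
    print(response.choices[0].text)


asyncio.run(main())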


@@ -1,28 +1,33 @@
 # What is this?
 ## Controller file for TextCompletionCodestral Integration - https://codestral.com/
-from functools import partial
-import os, types
-import traceback
 import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
-from typing import Callable, Optional, List, Literal, Union
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
-    TextCompletionResponse,
-    Usage,
+    Choices,
     CustomStreamWrapper,
     Message,
-    Choices,
+    TextCompletionResponse,
+    Usage,
 )
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.types.llms.databricks import GenericStreamingChunk
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
 from .base import BaseLLM
-import httpx  # type: ignore
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class TextCompletionCodestralError(Exception):

@@ -329,7 +334,12 @@ class CodestralTextCompletion(BaseLLM):
     ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
         headers = self._validate_environment(api_key, headers)
 
-        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
+        if optional_params.pop("custom_endpoint", None) is True:
+            completion_url = api_base
+        else:
+            completion_url = (
+                api_base or "https://codestral.mistral.ai/v1/fim/completions"
+            )
 
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
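
The `custom_endpoint` flag above is what lets other callers (such as the Vertex AI partner handler) reuse this controller with their own URL: when it is popped from `optional_params` and is `True`, the supplied `api_base` is used verbatim; otherwise the public Codestral FIM endpoint is the fallback. A condensed sketch of that selection logic, with a hypothetical helper name:

from typing import Optional


def pick_fim_url(api_base: Optional[str], optional_params: dict) -> Optional[str]:
    # Mirrors the branch above: a custom endpoint (e.g. a Vertex AI URL) wins
    # outright; otherwise fall back to Mistral's hosted FIM endpoint.
    if optional_params.pop("custom_endpoint", None) is True:
        return api_base
    return api_base or "https://codestral.mistral.ai/v1/fim/completions"
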
@@ -426,6 +436,7 @@ class CodestralTextCompletion(BaseLLM):
                 return _response
         ### SYNC COMPLETION
         else:
+
             response = requests.post(
                 url=completion_url,
                 headers=headers,

@@ -464,8 +475,11 @@ class CodestralTextCompletion(BaseLLM):
         headers={},
     ) -> TextCompletionResponse:
 
-        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+        )
+
         try:
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
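
The async branch now builds its handler with `concurrent_limit=1` and POSTs the JSON payload to the FIM endpoint. Roughly what that request looks like when issued with plain `httpx` (a sketch only; the payload keys follow Mistral's FIM API, and the `CODESTRAL_API_KEY` environment variable is an assumption):

import asyncio
import os

import httpx


async def raw_fim_call() -> dict:
    # Same request shape the handler above sends: model, prompt (code before the
    # cursor) and suffix (code after it), serialized as JSON.
    payload = {
        "model": "codestral-2405",
        "prompt": "def is_odd(n): \n    return n % 2 == 1 \ndef test_is_odd():",
        "suffix": "return True",
        "temperature": 0,
    }
    headers = {"Authorization": f"Bearer {os.environ['CODESTRAL_API_KEY']}"}
    async with httpx.AsyncClient(timeout=600.0) as client:
        resp = await client.post(
            "https://codestral.mistral.ai/v1/fim/completions",
            headers=headers,
            json=payload,
        )
        resp.raise_for_status()
        return resp.json()


print(asyncio.run(raw_fim_call()))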


@@ -140,10 +140,10 @@ class VertexAIPartnerModels(BaseLLM):
         custom_prompt_dict: dict,
         headers: Optional[dict],
         timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         vertex_project=None,
         vertex_location=None,
         vertex_credentials=None,
-        litellm_params=None,
         logger_fn=None,
         acompletion: bool = False,
         client=None,

@@ -154,6 +154,7 @@ class VertexAIPartnerModels(BaseLLM):
             from litellm.llms.databricks import DatabricksChatCompletion
             from litellm.llms.openai import OpenAIChatCompletion
+            from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_httpx import VertexLLM
         except Exception:

@@ -178,12 +179,7 @@ class VertexAIPartnerModels(BaseLLM):
             )
 
             openai_like_chat_completions = DatabricksChatCompletion()
-
-            ## Load Config
-            # config = litellm.VertexAILlama3.get_config()
-            # for k, v in config.items():
-            #     if k not in optional_params:
-            #         optional_params[k] = v
+            codestral_fim_completions = CodestralTextCompletion()
 
             ## CONSTRUCT API BASE
             stream: bool = optional_params.get("stream", False) or False

@@ -206,6 +202,28 @@ class VertexAIPartnerModels(BaseLLM):
                 model = model.split("@")[0]
 
+            if "codestral" in model and litellm_params.get("text_completion") is True:
+                optional_params["model"] = model
+                text_completion_model_response = litellm.TextCompletionResponse(
+                    stream=stream
+                )
+                return codestral_fim_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    api_key=access_token,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=text_completion_model_response,
+                    print_verbose=print_verbose,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    acompletion=acompletion,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,
+                    encoding=encoding,
+                )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
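
With the branch above, one partner model now fans out two ways depending on how it is invoked. A caller-side sketch (the chat path existed before this commit; the FIM path is the new one, and it is expected to set the `text_completion` flag that `main.py` forwards below):

import asyncio

import litellm

# 1) Chat-style request: stays on the OpenAI-compatible chat path
#    (openai_like_chat_completions.completion above).
chat = litellm.completion(
    model="vertex_ai/codestral@2405",
    messages=[{"role": "user", "content": "Say hello."}],
)

# 2) Text/FIM request: routed to the new codestral_fim_completions branch.
fim = asyncio.run(
    litellm.atext_completion(
        model="vertex_ai/codestral@2405",
        prompt="def is_odd(n): \n    return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
    )
)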


@@ -986,6 +986,7 @@ def completion(
             output_cost_per_second=output_cost_per_second,
             output_cost_per_token=output_cost_per_token,
             cooldown_time=cooldown_time,
+            text_completion=kwargs.get("text_completion"),
         )
         logging.update_environment_variables(
             model=model,


@@ -4104,9 +4104,19 @@ async def test_async_text_completion_chat_model_stream():
 # asyncio.run(test_async_text_completion_chat_model_stream())
 
 
+@pytest.mark.parametrize(
+    "model", ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"]  #
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api():
+async def test_completion_codestral_fim_api(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
+
         litellm.set_verbose = True
         import logging

@@ -4114,7 +4124,7 @@ async def test_completion_codestral_fim_api():
         verbose_logger.setLevel(level=logging.DEBUG)
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,

@@ -4137,9 +4147,19 @@ async def test_completion_codestral_fim_api():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize(
+    "model",
+    ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"],
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api_stream():
+async def test_completion_codestral_fim_api_stream(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
+
         import logging
 
         from litellm._logging import verbose_logger

@@ -4148,7 +4168,7 @@ async def test_completion_codestral_fim_api_stream():
         # verbose_logger.setLevel(level=logging.DEBUG)
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,


@@ -2258,6 +2258,7 @@ def get_litellm_params(
     output_cost_per_token=None,
     output_cost_per_second=None,
     cooldown_time=None,
+    text_completion=None,
 ):
     litellm_params = {
         "acompletion": acompletion,

@@ -2281,6 +2282,7 @@ def get_litellm_params(
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
+        "text_completion": text_completion,
     }
 
     return litellm_params

@@ -3127,10 +3129,15 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.MistralConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-        )
+        if "codestral" in model:
+            optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+                non_default_params=non_default_params, optional_params=optional_params
+            )
+        else:
+            optional_params = litellm.MistralConfig().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
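
A quick way to see what the new codestral branch does with OpenAI-style arguments; a sketch only, since the exact set of mapped keys is defined by `MistralTextCompletionConfig`, not by this snippet:

import litellm

# Map OpenAI-style kwargs onto Codestral text-completion params, as the branch
# above now does for vertex_ai codestral models. Input values are illustrative.
mapped = litellm.MistralTextCompletionConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 10},
    optional_params={},
)
print(mapped)  # expected to carry temperature / max_tokens through
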
@@ -4239,6 +4246,10 @@ def get_supported_openai_params(
                 return litellm.VertexAILlama3Config().get_supported_openai_params()
             if model.startswith("mistral"):
                 return litellm.MistralConfig().get_supported_openai_params()
+            if model.startswith("codestral"):
+                return (
+                    litellm.MistralTextCompletionConfig().get_supported_openai_params()
+                )
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
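
With the extra `model.startswith("codestral")` branch, parameter discovery for Vertex AI codestral models resolves to the text-completion parameter set. A hedged usage sketch (assuming `litellm.get_supported_openai_params` is called with the vertex_ai provider, as in the branch above):

import litellm

# For a Vertex AI codestral model, chat_completion request types now return
# MistralTextCompletionConfig's supported params instead of the generic
# VertexAIConfig list.
params = litellm.get_supported_openai_params(
    model="codestral@2405",
    custom_llm_provider="vertex_ai",
    request_type="chat_completion",
)
print(params)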