diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
index e732706b4a..7c758f5b5f 100644
--- a/litellm/llms/text_completion_codestral.py
+++ b/litellm/llms/text_completion_codestral.py
@@ -1,28 +1,33 @@
 # What is this?
 ## Controller file for TextCompletionCodestral Integration - https://codestral.com/
 
-from functools import partial
-import os, types
-import traceback
+import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
-from typing import Callable, Optional, List, Literal, Union
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
-    TextCompletionResponse,
-    Usage,
+    Choices,
     CustomStreamWrapper,
     Message,
-    Choices,
+    TextCompletionResponse,
+    Usage,
 )
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.types.llms.databricks import GenericStreamingChunk
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
 from .base import BaseLLM
-import httpx  # type: ignore
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class TextCompletionCodestralError(Exception):
@@ -329,7 +334,12 @@ class CodestralTextCompletion(BaseLLM):
     ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
         headers = self._validate_environment(api_key, headers)
 
-        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
+        if optional_params.pop("custom_endpoint", None) is True:
+            completion_url = api_base
+        else:
+            completion_url = (
+                api_base or "https://codestral.mistral.ai/v1/fim/completions"
+            )
 
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
@@ -426,6 +436,7 @@ class CodestralTextCompletion(BaseLLM):
             return _response
         ### SYNC COMPLETION
         else:
+
             response = requests.post(
                 url=completion_url,
                 headers=headers,
@@ -464,8 +475,11 @@ class CodestralTextCompletion(BaseLLM):
         headers={},
     ) -> TextCompletionResponse:
-        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+        )
         try:
+
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
 
diff --git a/litellm/llms/vertex_ai_partner.py b/litellm/llms/vertex_ai_partner.py
index 7642029173..08780be765 100644
--- a/litellm/llms/vertex_ai_partner.py
+++ b/litellm/llms/vertex_ai_partner.py
@@ -140,10 +140,10 @@ class VertexAIPartnerModels(BaseLLM):
         custom_prompt_dict: dict,
         headers: Optional[dict],
         timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         vertex_project=None,
         vertex_location=None,
         vertex_credentials=None,
-        litellm_params=None,
         logger_fn=None,
         acompletion: bool = False,
         client=None,
@@ -154,6 +154,7 @@ class VertexAIPartnerModels(BaseLLM):
 
             from litellm.llms.databricks import DatabricksChatCompletion
             from litellm.llms.openai import OpenAIChatCompletion
+            from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_httpx import VertexLLM
         except Exception:
@@ -178,12 +179,7 @@ class VertexAIPartnerModels(BaseLLM):
             )
 
             openai_like_chat_completions = DatabricksChatCompletion()
-
-            ## Load Config
-            # config = litellm.VertexAILlama3.get_config()
-            # for k, v in config.items():
-            #     if k not in optional_params:
-            #         optional_params[k] = v
+            codestral_fim_completions = CodestralTextCompletion()
 
             ## CONSTRUCT API BASE
             stream: bool = optional_params.get("stream", False) or False
@@ -206,6 +202,28 @@ class VertexAIPartnerModels(BaseLLM):
 
             model = model.split("@")[0]
 
+            if "codestral" in model and litellm_params.get("text_completion") is True:
+                optional_params["model"] = model
+                text_completion_model_response = litellm.TextCompletionResponse(
+                    stream=stream
+                )
+                return codestral_fim_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    api_key=access_token,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=text_completion_model_response,
+                    print_verbose=print_verbose,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    acompletion=acompletion,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,
+                    encoding=encoding,
+                )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
diff --git a/litellm/main.py b/litellm/main.py
index b01fba5b24..fb598ff8e7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -986,6 +986,7 @@ def completion(
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
         cooldown_time=cooldown_time,
+        text_completion=kwargs.get("text_completion"),
     )
     logging.update_environment_variables(
         model=model,
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 6a0080b373..618a6095f0 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -4104,9 +4104,19 @@ async def test_async_text_completion_chat_model_stream():
 # asyncio.run(test_async_text_completion_chat_model_stream())
 
 
+@pytest.mark.parametrize(
+    "model", ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"]  #
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api():
+async def test_completion_codestral_fim_api(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
+
         litellm.set_verbose = True
 
         import logging
@@ -4114,7 +4124,7 @@ async def test_completion_codestral_fim_api():
         verbose_logger.setLevel(level=logging.DEBUG)
 
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,
@@ -4137,9 +4147,19 @@ async def test_completion_codestral_fim_api():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize(
+    "model",
+    ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"],
+)
 @pytest.mark.asyncio
-async def test_completion_codestral_fim_api_stream():
+async def test_completion_codestral_fim_api_stream(model):
     try:
+        if model == "vertex_ai/codestral@2405":
+            from litellm.tests.test_amazing_vertex_completion import (
+                load_vertex_ai_credentials,
+            )
+
+            load_vertex_ai_credentials()
         import logging
 
         from litellm._logging import verbose_logger
@@ -4148,7 +4168,7 @@ async def test_completion_codestral_fim_api_stream():
         # verbose_logger.setLevel(level=logging.DEBUG)
 
         response = await litellm.atext_completion(
-            model="text-completion-codestral/codestral-2405",
+            model=model,
             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
             suffix="return True",
             temperature=0,
diff --git a/litellm/utils.py b/litellm/utils.py
index 0e1573784b..da02362c04 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2258,6 +2258,7 @@ def get_litellm_params(
     output_cost_per_token=None,
     output_cost_per_second=None,
     cooldown_time=None,
+    text_completion=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2281,6 +2282,7 @@ def get_litellm_params(
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
         "cooldown_time": cooldown_time,
+        "text_completion": text_completion,
     }
 
     return litellm_params
@@ -3127,10 +3129,15 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.MistralConfig().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-        )
+        if "codestral" in model:
+            optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+                non_default_params=non_default_params, optional_params=optional_params
+            )
+        else:
+            optional_params = litellm.MistralConfig().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -4239,6 +4246,10 @@ def get_supported_openai_params(
             return litellm.VertexAILlama3Config().get_supported_openai_params()
         if model.startswith("mistral"):
            return litellm.MistralConfig().get_supported_openai_params()
+        if model.startswith("codestral"):
+            return (
+                litellm.MistralTextCompletionConfig().get_supported_openai_params()
+            )
         return litellm.VertexAIConfig().get_supported_openai_params()
     elif request_type == "embeddings":
         return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
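
Usage sketch (not part of the patch): a minimal async FIM call mirroring the parametrized test above. It assumes Vertex AI credentials are already configured in the environment (the tests call load_vertex_ai_credentials() for this) and that the codestral@2405 partner model is enabled for the project; the prompt/suffix values are the illustrative ones from the test, and reading response.choices[0].text assumes the usual text-completion response shape.

import asyncio

import litellm


async def main():
    # Fill-in-the-middle completion routed through the Vertex AI partner path
    # added by this patch. Swap the model string for
    # "text-completion-codestral/codestral-2405" to call the Mistral API directly.
    response = await litellm.atext_completion(
        model="vertex_ai/codestral@2405",
        prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
        temperature=0,
    )
    # Assumes the standard TextCompletionResponse shape exercised by the tests above.
    print(response.choices[0].text)


asyncio.run(main())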