diff --git a/litellm/__init__.py b/litellm/__init__.py
index bcf764f835..dc2150270c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -795,11 +795,11 @@ from .llms.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
     MistralConfig,
-    MistralTextCompletionConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
     AzureAIStudioConfig,
 )
+from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.azure import (
     AzureOpenAIConfig,
     AzureOpenAIError,
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index 0222d2366c..7e80de9ab1 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -27,6 +27,25 @@ class BaseLLM:
         """
         return model_response
 
+    def process_text_completion_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.TextCompletionResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 976d1f5668..1f2b836c3a 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -208,85 +208,6 @@ class MistralEmbeddingConfig:
         return optional_params
 
 
-class MistralTextCompletionConfig:
-    """
-    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
-    """
-
-    suffix: Optional[str] = None
-    temperature: Optional[int] = None
-    top_p: Optional[float] = None
-    max_tokens: Optional[int] = None
-    min_tokens: Optional[int] = None
-    stream: Optional[bool] = None
-    random_seed: Optional[int] = None
-    stop: Optional[str] = None
-
-    def __init__(
-        self,
-        suffix: Optional[str] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        min_tokens: Optional[int] = None,
-        stream: Optional[bool] = None,
-        random_seed: Optional[int] = None,
-        stop: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "suffix",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "stream",
-            "seed",
-            "stop",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "suffix":
-                optional_params["suffix"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "max_tokens":
-                optional_params["max_tokens"] = value
-            if param == "stream" and value == True:
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "seed":
-                optional_params["extra_body"] = {"random_seed": value}
-
-        return optional_params
-
-
 class AzureAIStudioConfig:
     def get_required_params(self) -> List[ProviderField]:
         """For a given provider, return it's required fields with a description"""
diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
new file mode 100644
index 0000000000..124c840c36
--- /dev/null
+++ b/litellm/llms/text_completion_codestral.py
@@ -0,0 +1,461 @@
+# What is this?
+## Controller file for TextCompletionCodestral Integration - https://codestral.com/
+
+from functools import partial
+import os, types
+import traceback
+import json
+from enum import Enum
+import requests, copy  # type: ignore
+import time
+from typing import Callable, Optional, List, Literal, Union
+from litellm.utils import (
+    TextCompletionResponse,
+    Usage,
+    CustomStreamWrapper,
+    Message,
+    Choices,
+)
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from .base import BaseLLM
+import httpx  # type: ignore
+
+
+class TextCompletionCodestralError(Exception):
+    def __init__(
+        self,
+        status_code,
+        message,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        if request is not None:
+            self.request = request
+        else:
+            self.request = httpx.Request(
+                method="POST",
+                url="https://docs.codestral.com/user-guide/inference/rest_api",
+            )
+        if response is not None:
+            self.response = response
+        else:
+            self.response = httpx.Response(
+                status_code=status_code, request=self.request
+            )
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise TextCompletionCodestralError(
+            status_code=response.status_code, message=response.text
+        )
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+class MistralTextCompletionConfig:
+    """
+    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
+    """
+
+    suffix: Optional[str] = None
+    temperature: Optional[int] = None
+    top_p: Optional[float] = None
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = None
+    random_seed: Optional[int] = None
+    stop: Optional[str] = None
+
+    def __init__(
+        self,
+        suffix: Optional[str] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        min_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[str] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "suffix",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
"seed", + "stop", + ] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "suffix": + optional_params["suffix"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "stream" and value == True: + optional_params["stream"] = value + if param == "stop": + optional_params["stop"] = value + if param == "seed": + optional_params["random_seed"] = value + if param == "min_tokens": + optional_params["min_tokens"] = value + + return optional_params + + +class CodestralTextCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def _validate_environment( + self, + api_key: Optional[str], + user_headers: dict, + ) -> dict: + if api_key is None: + raise ValueError( + "Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables" + ) + headers = { + "content-type": "application/json", + "Authorization": "Bearer {}".format(api_key), + } + if user_headers is not None and isinstance(user_headers, dict): + headers = {**headers, **user_headers} + return headers + + def output_parser(self, generated_text: str): + """ + Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. + + Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763 + """ + chat_template_tokens = [ + "<|assistant|>", + "<|system|>", + "<|user|>", + "", + "", + ] + for token in chat_template_tokens: + if generated_text.strip().startswith(token): + generated_text = generated_text.replace(token, "", 1) + if generated_text.endswith(token): + generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] + return generated_text + + def process_text_completion_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: TextCompletionResponse, + stream: bool, + logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: list, + print_verbose, + encoding, + ) -> TextCompletionResponse: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + if response.status_code != 200: + raise TextCompletionCodestralError( + message=str(response.text), + status_code=response.status_code, + ) + try: + completion_response = response.json() + except: + raise TextCompletionCodestralError(message=response.text, status_code=422) + + _response = litellm.TextCompletionResponse(**completion_response) + return _response + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: TextCompletionResponse, + print_verbose: Callable, + encoding, + api_key: str, + logging_obj, + optional_params: dict, + timeout: Union[float, httpx.Timeout], + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: dict = {}, + ) -> Union[TextCompletionResponse, CustomStreamWrapper]: + headers = self._validate_environment(api_key, headers) + + completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions" + + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + 
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+
+        ## Load Config
+        config = litellm.MistralTextCompletionConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        stream = optional_params.pop("stream", False)
+
+        data = {
+            "prompt": prompt,
+            **optional_params,
+        }
+        input_text = prompt
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input_text,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": completion_url,
+                "acompletion": acompletion,
+            },
+        )
+        ## COMPLETION CALL
+        if acompletion is True:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+            else:
+                ### ASYNC COMPLETION
+                return self.async_completion(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=False,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+
+        ### SYNC STREAMING
+        if stream is True:
+            response = requests.post(
+                completion_url,
+                headers=headers,
+                data=json.dumps(data),
+                stream=stream,
+            )
+            _response = CustomStreamWrapper(
+                response.iter_lines(),
+                model,
+                custom_llm_provider="codestral",
+                logging_obj=logging_obj,
+            )
+            return _response
+        ### SYNC COMPLETION
+        else:
+            response = requests.post(
+                url=completion_url,
+                headers=headers,
+                data=json.dumps(data),
+            )
+        return self.process_text_completion_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,  # type: ignore
+            optional_params=optional_params,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: TextCompletionResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        data: dict,
+        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> TextCompletionResponse:
+
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        try:
+            response = await async_handler.post(
+                api_base, headers=headers, data=json.dumps(data)
+            )
+        except httpx.HTTPStatusError as e:
+            raise TextCompletionCodestralError(
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
+            )
+        except Exception as e:
+            raise TextCompletionCodestralError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
+        return self.process_text_completion_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: TextCompletionResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> CustomStreamWrapper:
+        data["stream"] = True
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="codestral",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    def embedding(self, *args, **kwargs):
+        pass
diff --git a/litellm/main.py b/litellm/main.py
index 648802620b..0540d29cde 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -82,6 +82,7 @@ from .llms.predibase import PredibaseChatCompletion
 from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
 from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
+from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -120,6 +121,7 @@ azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
+codestral_text_completions = CodestralTextCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
@@ -2027,6 +2029,46 @@ def completion(
             timeout=timeout,
         )
 
+        if (
+            "stream" in optional_params
+            and optional_params["stream"] is True
+            and acompletion is False
+        ):
+            return _model_response
+        response = _model_response
+    elif custom_llm_provider == "text-completion-codestral":
+
+        api_base = (
+            api_base
+            or optional_params.pop("api_base", None)
+            or optional_params.pop("base_url", None)
+            or litellm.api_base
+            or "https://codestral.mistral.ai/v1/fim/completions"
+        )
+
+        api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
+
+        text_completion_model_response = litellm.TextCompletionResponse(
+            stream=stream
+        )
+
+        _model_response = codestral_text_completions.completion(  # type: ignore
+            model=model,
+            messages=messages,
+            model_response=text_completion_model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            logging_obj=logging,
+            acompletion=acompletion,
+            api_base=api_base,
+            custom_prompt_dict=custom_prompt_dict,
+            api_key=api_key,
+            timeout=timeout,
+        )
+
         if (
             "stream" in optional_params
             and optional_params["stream"] is True
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 3ec3954fde..1ddd8ea6bb 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -4078,29 +4078,30 @@ async def test_async_text_completion_chat_model_stream():
 # asyncio.run(test_async_text_completion_chat_model_stream())
 
 
-# @pytest.mark.asyncio
-# async def test_completion_codestral_fim_api():
-#     try:
-#         litellm.set_verbose = True
-#         from litellm._logging import verbose_logger
-#         import logging
-#         verbose_logger.setLevel(level=logging.DEBUG)
-#         response = await litellm.atext_completion(
-#             model="text-completion-codestral/codestral-2405",
-#             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
-#             suffix="return True",
-#             temperature=0,
-#             top_p=0.4,
-#             max_tokens=10,
-#             # min_tokens=10,
-#             seed=10,
-#             stop=["return"],
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
+@pytest.mark.asyncio
+async def test_completion_codestral_fim_api():
+    try:
+        litellm.set_verbose = True
+        from litellm._logging import verbose_logger
+        import logging
 
-#         # cost = litellm.completion_cost(completion_response=response)
-#         # print("cost to make mistral completion=", cost)
-#         # assert cost > 0.0
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+        verbose_logger.setLevel(level=logging.DEBUG)
+        response = await litellm.atext_completion(
+            model="text-completion-codestral/codestral-2405",
+            prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
+            suffix="return True",
+            temperature=0,
+            top_p=1,
+            max_tokens=10,
+            min_tokens=10,
+            seed=10,
+            stop=["return"],
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 054648825b..f66077d7a4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3017,29 +3017,6 @@ def get_optional_params(
         optional_params["response_format"] = response_format
         if seed is not None:
             optional_params["seed"] = seed
-    elif custom_llm_provider == "codestral":
-        # supported_params = get_supported_openai_params(
-        #     model=model, custom_llm_provider=custom_llm_provider
-        # )
-        # _check_valid_arg(supported_params=supported_params)
-        # optional_params = litellm.DeepInfraConfig().map_openai_params(
-        #     non_default_params=non_default_params,
-        #     optional_params=optional_params,
-        #     model=model,
-        # )
-        pass
-    elif custom_llm_provider == "text-completion-codestral":
-        # supported_params = get_supported_openai_params(
-        #     model=model, custom_llm_provider=custom_llm_provider
-        # )
-        # _check_valid_arg(supported_params=supported_params)
-        # optional_params = litellm.DeepInfraConfig().map_openai_params(
-        #     non_default_params=non_default_params,
-        #     optional_params=optional_params,
-        #     model=model,
-        # )
-        pass
-
     elif custom_llm_provider == "deepseek":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3906,10 +3883,6 @@ def get_llm_provider(
             # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
             api_base = "https://codestral.mistral.ai/v1"
             dynamic_api_key = get_secret("CODESTRAL_API_KEY")
-        elif custom_llm_provider == "text-completion-codestral":
-            # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
-            api_base = "https://codestral.mistral.ai/v1"
-            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
         elif custom_llm_provider == "deepseek":
             # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
             api_base = "https://api.deepseek.com/v1"
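
For reviewers, here is a minimal usage sketch (not part of the diff) of the new "text-completion-codestral" provider path. It mirrors the uncommented test above; the model name `codestral-2405` is taken from that test, the key is read from `CODESTRAL_API_KEY` as in `_validate_environment`, and the streaming call exercises the `async_streaming()` / `make_call()` branch added in `text_completion_codestral.py`. Exact response/chunk shapes follow litellm's `TextCompletionResponse`.

```python
# Hedged sketch - assumes CODESTRAL_API_KEY is exported and the codestral-2405
# model (from the test in this PR) is available on your account.
import asyncio
import os

import litellm

assert os.getenv("CODESTRAL_API_KEY"), "export CODESTRAL_API_KEY before running"


async def main():
    # Non-streaming FIM completion - routed through CodestralTextCompletion.async_completion()
    response = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
        temperature=0,
        max_tokens=10,
    )
    print(response)

    # Streaming variant - routed through async_streaming() + make_call()
    stream = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def add(a, b):",
        max_tokens=10,
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```

Synchronous calls via `litellm.text_completion(...)` follow the requests-based branch of `CodestralTextCompletion.completion()` instead.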