feat - add fim codestral api

This commit is contained in:
Ishaan Jaff 2024-06-17 13:46:03 -07:00
parent ad47fee181
commit 364492297d
7 changed files with 549 additions and 132 deletions

View file

@ -795,11 +795,11 @@ from .llms.openai import (
OpenAIConfig, OpenAIConfig,
OpenAITextCompletionConfig, OpenAITextCompletionConfig,
MistralConfig, MistralConfig,
MistralTextCompletionConfig,
MistralEmbeddingConfig, MistralEmbeddingConfig,
DeepInfraConfig, DeepInfraConfig,
AzureAIStudioConfig, AzureAIStudioConfig,
) )
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.azure import ( from .llms.azure import (
AzureOpenAIConfig, AzureOpenAIConfig,
AzureOpenAIError, AzureOpenAIError,

View file

@ -27,6 +27,25 @@ class BaseLLM:
""" """
return model_response return model_response
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.TextCompletionResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session

View file

@ -208,85 +208,6 @@ class MistralEmbeddingConfig:
return optional_params return optional_params
class MistralTextCompletionConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
suffix: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stream: Optional[bool] = None,
random_seed: Optional[int] = None,
stop: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"suffix",
"temperature",
"top_p",
"max_tokens",
"stream",
"seed",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
return optional_params
class AzureAIStudioConfig: class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]: def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description""" """For a given provider, return it's required fields with a description"""

View file

@ -0,0 +1,461 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/
from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
TextCompletionResponse,
Usage,
CustomStreamWrapper,
Message,
Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TextCompletionCodestralError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(
method="POST",
url="https://docs.codestral.com/user-guide/inference/rest_api",
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
async def make_call(
client: AsyncHTTPHandler,
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise TextCompletionCodestralError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class MistralTextCompletionConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
suffix: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stream: Optional[bool] = None,
random_seed: Optional[int] = None,
stop: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"suffix",
"temperature",
"top_p",
"max_tokens",
"stream",
"seed",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "seed":
optional_params["random_seed"] = value
if param == "min_tokens":
optional_params["min_tokens"] = value
return optional_params
class CodestralTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(
self,
api_key: Optional[str],
user_headers: dict,
) -> dict:
if api_key is None:
raise ValueError(
"Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def output_parser(self, generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = [
"<|assistant|>",
"<|system|>",
"<|user|>",
"<s>",
"</s>",
]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: TextCompletionResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> TextCompletionResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:
raise TextCompletionCodestralError(
message=str(response.text),
status_code=response.status_code,
)
try:
completion_response = response.json()
except:
raise TextCompletionCodestralError(message=response.text, status_code=422)
_response = litellm.TextCompletionResponse(**completion_response)
return _response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key: str,
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## Load Config
config = litellm.MistralTextCompletionConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
data = {
"prompt": prompt,
**optional_params,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
else:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
### SYNC STREAMING
if stream is True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=stream,
)
_response = CustomStreamWrapper(
response.iter_lines(),
model,
custom_llm_provider="codestral",
logging_obj=logging_obj,
)
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=optional_params.get("stream", False),
logging_obj=logging_obj, # type: ignore
optional_params=optional_params,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params=None,
logger_fn=None,
headers={},
) -> TextCompletionResponse:
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
except httpx.HTTPStatusError as e:
raise TextCompletionCodestralError(
status_code=e.response.status_code,
message="HTTPStatusError - {}".format(e.response.text),
)
except Exception as e:
raise TextCompletionCodestralError(
status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data: dict,
timeout: Union[float, httpx.Timeout],
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
data["stream"] = True
streamwrapper = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="codestral",
logging_obj=logging_obj,
)
return streamwrapper
def embedding(self, *args, **kwargs):
pass

View file

@ -82,6 +82,7 @@ from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion from .llms.triton import TritonChatCompletion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
@ -120,6 +121,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion() azure_text_completions = AzureTextCompletion()
huggingface = Huggingface() huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion() predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion() triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
@ -2027,6 +2029,46 @@ def completion(
timeout=timeout, timeout=timeout,
) )
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or "https://codestral.mistral.ai/v1/fim/completions"
)
api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
_model_response = codestral_text_completions.completion( # type: ignore
model=model,
messages=messages,
model_response=text_completion_model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
timeout=timeout,
)
if ( if (
"stream" in optional_params "stream" in optional_params
and optional_params["stream"] is True and optional_params["stream"] is True

View file

@ -4078,29 +4078,30 @@ async def test_async_text_completion_chat_model_stream():
# asyncio.run(test_async_text_completion_chat_model_stream()) # asyncio.run(test_async_text_completion_chat_model_stream())
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_completion_codestral_fim_api(): async def test_completion_codestral_fim_api():
# try: try:
# litellm.set_verbose = True litellm.set_verbose = True
# from litellm._logging import verbose_logger from litellm._logging import verbose_logger
# import logging import logging
# verbose_logger.setLevel(level=logging.DEBUG)
# response = await litellm.atext_completion(
# model="text-completion-codestral/codestral-2405",
# prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
# suffix="return True",
# temperature=0,
# top_p=0.4,
# max_tokens=10,
# # min_tokens=10,
# seed=10,
# stop=["return"],
# )
# # Add any assertions here to check the response
# print(response)
# # cost = litellm.completion_cost(completion_response=response) verbose_logger.setLevel(level=logging.DEBUG)
# # print("cost to make mistral completion=", cost) response = await litellm.atext_completion(
# # assert cost > 0.0 model="text-completion-codestral/codestral-2405",
# except Exception as e: prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
# pytest.fail(f"Error occurred: {e}") suffix="return True",
temperature=0,
top_p=1,
max_tokens=10,
min_tokens=10,
seed=10,
stop=["return"],
)
# Add any assertions here to check the response
print(response)
# cost = litellm.completion_cost(completion_response=response)
# print("cost to make mistral completion=", cost)
# assert cost > 0.0
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -3017,29 +3017,6 @@ def get_optional_params(
optional_params["response_format"] = response_format optional_params["response_format"] = response_format
if seed is not None: if seed is not None:
optional_params["seed"] = seed optional_params["seed"] = seed
elif custom_llm_provider == "codestral":
# supported_params = get_supported_openai_params(
# model=model, custom_llm_provider=custom_llm_provider
# )
# _check_valid_arg(supported_params=supported_params)
# optional_params = litellm.DeepInfraConfig().map_openai_params(
# non_default_params=non_default_params,
# optional_params=optional_params,
# model=model,
# )
pass
elif custom_llm_provider == "text-completion-codestral":
# supported_params = get_supported_openai_params(
# model=model, custom_llm_provider=custom_llm_provider
# )
# _check_valid_arg(supported_params=supported_params)
# optional_params = litellm.DeepInfraConfig().map_openai_params(
# non_default_params=non_default_params,
# optional_params=optional_params,
# model=model,
# )
pass
elif custom_llm_provider == "deepseek": elif custom_llm_provider == "deepseek":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3906,10 +3883,6 @@ def get_llm_provider(
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1 # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = "https://codestral.mistral.ai/v1" api_base = "https://codestral.mistral.ai/v1"
dynamic_api_key = get_secret("CODESTRAL_API_KEY") dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "text-completion-codestral":
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = "https://codestral.mistral.ai/v1"
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "deepseek": elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = "https://api.deepseek.com/v1" api_base = "https://api.deepseek.com/v1"