feat add text completion config for mistral text

This commit is contained in:
Ishaan Jaff 2024-06-17 12:48:46 -07:00
parent 5f76f96e4d
commit ad47fee181
5 changed files with 120 additions and 19 deletions

View file

@ -404,7 +404,6 @@ openai_compatible_providers: List = [
"mistral", "mistral",
"groq", "groq",
"codestral", "codestral",
"text-completion-codestral",
"deepseek", "deepseek",
"deepinfra", "deepinfra",
"perplexity", "perplexity",
@ -796,6 +795,7 @@ from .llms.openai import (
OpenAIConfig, OpenAIConfig,
OpenAITextCompletionConfig, OpenAITextCompletionConfig,
MistralConfig, MistralConfig,
MistralTextCompletionConfig,
MistralEmbeddingConfig, MistralEmbeddingConfig,
DeepInfraConfig, DeepInfraConfig,
AzureAIStudioConfig, AzureAIStudioConfig,

View file

@ -208,6 +208,85 @@ class MistralEmbeddingConfig:
return optional_params return optional_params
class MistralTextCompletionConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
suffix: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stream: Optional[bool] = None,
random_seed: Optional[int] = None,
stop: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"suffix",
"temperature",
"top_p",
"max_tokens",
"stream",
"seed",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
return optional_params
class AzureAIStudioConfig: class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]: def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description""" """For a given provider, return it's required fields with a description"""

View file

@ -1049,7 +1049,6 @@ def completion(
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "codestral" or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
@ -3711,6 +3710,7 @@ def text_completion(
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text" or custom_llm_provider == "azure_text"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
) )
and isinstance(prompt, list) and isinstance(prompt, list)

View file

@ -4078,19 +4078,29 @@ async def test_async_text_completion_chat_model_stream():
# asyncio.run(test_async_text_completion_chat_model_stream()) # asyncio.run(test_async_text_completion_chat_model_stream())
@pytest.mark.asyncio # @pytest.mark.asyncio
async def test_completion_codestral_fim_api(): # async def test_completion_codestral_fim_api():
try: # try:
litellm.set_verbose = True # litellm.set_verbose = True
response = await litellm.atext_completion( # from litellm._logging import verbose_logger
model="text-completion-codestral/codestral-2405", # import logging
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():", # verbose_logger.setLevel(level=logging.DEBUG)
) # response = await litellm.atext_completion(
# Add any assertions here to check the response # model="text-completion-codestral/codestral-2405",
print(response) # prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
# suffix="return True",
# temperature=0,
# top_p=0.4,
# max_tokens=10,
# # min_tokens=10,
# seed=10,
# stop=["return"],
# )
# # Add any assertions here to check the response
# print(response)
# cost = litellm.completion_cost(completion_response=response) # # cost = litellm.completion_cost(completion_response=response)
# print("cost to make mistral completion=", cost) # # print("cost to make mistral completion=", cost)
# assert cost > 0.0 # # assert cost > 0.0
except Exception as e: # except Exception as e:
pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")

View file

@ -2968,7 +2968,7 @@ def get_optional_params(
optional_params["stream"] = stream optional_params["stream"] = stream
if max_tokens: if max_tokens:
optional_params["max_tokens"] = max_tokens optional_params["max_tokens"] = max_tokens
elif custom_llm_provider == "mistral": elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
@ -2976,6 +2976,15 @@ def get_optional_params(
optional_params = litellm.MistralConfig().map_openai_params( optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params non_default_params=non_default_params, optional_params=optional_params
) )
elif custom_llm_provider == "text-completion-codestral":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params
)
elif custom_llm_provider == "databricks": elif custom_llm_provider == "databricks":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3649,11 +3658,14 @@ def get_supported_openai_params(
"tool_choice", "tool_choice",
"max_retries", "max_retries",
] ]
elif custom_llm_provider == "mistral": elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistal and codestral api have the exact same params
if request_type == "chat_completion": if request_type == "chat_completion":
return litellm.MistralConfig().get_supported_openai_params() return litellm.MistralConfig().get_supported_openai_params()
elif request_type == "embeddings": elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params() return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "text-completion-codestral":
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
elif custom_llm_provider == "replicate": elif custom_llm_provider == "replicate":
return [ return [
"stream", "stream",