fix(openai.py): create MistralConfig with response_format mapping for mistral api

Krrish Dholakia 2024-05-13 13:29:43 -07:00
parent 20fe4ffd6b
commit 20456968e9
5 changed files with 129 additions and 46 deletions
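
For orientation, a minimal sketch of the calling pattern this change targets (illustrative, not part of the commit; it assumes a valid MISTRAL_API_KEY and borrows the model name, prompt, and seed value from the diffs below):

import litellm

# `seed` is an OpenAI-style parameter; with this change it is accepted for
# mistral/* models and translated to Mistral's `random_seed` via `extra_body`.
response = litellm.completion(
    model="mistral/mistral-large-latest",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    seed=10,
)
print(response)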

@@ -755,7 +755,7 @@ from .llms.bedrock import (
     AmazonMistralConfig,
     AmazonBedrockGlobalConfig,
 )
-from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
+from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
 from .llms.watsonx import IBMWatsonXAIConfig
 from .main import *  # type: ignore

@@ -53,6 +53,113 @@ class OpenAIError(Exception):
         )  # Call the base class constructor with the parameters it needs


+class MistralConfig:
+    """
+    Reference: https://docs.mistral.ai/api/
+
+    The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
+
+    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
+
+    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
+
+    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
+
+    - `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
+
+    - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
+
+    - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
+
+    - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
+
+    - `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
+    """
+
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    max_tokens: Optional[int] = None
+    tools: Optional[list] = None
+    tool_choice: Optional[Literal["auto", "any", "none"]] = None
+    random_seed: Optional[int] = None
+    safe_prompt: Optional[bool] = None
+    response_format: Optional[dict] = None
+
+    def __init__(
+        self,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        tools: Optional[list] = None,
+        tool_choice: Optional[Literal["auto", "any", "none"]] = None,
+        random_seed: Optional[int] = None,
+        safe_prompt: Optional[bool] = None,
+        response_format: Optional[dict] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+            "seed",
+            "response_format",
+        ]
+
+    def _map_tool_choice(self, tool_choice: str) -> str:
+        if tool_choice == "auto" or tool_choice == "none":
+            return tool_choice
+        elif tool_choice == "required":
+            return "any"
+        else:  # openai 'tool_choice' object param not supported by Mistral API
+            return "any"
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "stream" and value == True:
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "tool_choice" and isinstance(value, str):
+                optional_params["tool_choice"] = self._map_tool_choice(
+                    tool_choice=value
+                )
+            if param == "seed":
+                optional_params["extra_body"] = {"random_seed": value}
+        return optional_params
+
+
 class OpenAIConfig:
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create
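
For orientation, a small sketch (not part of the diff) of what the mapping above produces, assuming the class is exposed as litellm.MistralConfig via the import change in the first hunk above:

from litellm import MistralConfig

# Illustrative input/output of map_openai_params; input keys use the OpenAI spelling.
optional_params = MistralConfig().map_openai_params(
    non_default_params={
        "temperature": 0.2,
        "seed": 10,                 # Mistral has no `seed`; becomes extra_body["random_seed"]
        "tool_choice": "required",  # not a Mistral value; coerced to "any"
    },
    optional_params={},
)
# optional_params == {"temperature": 0.2, "extra_body": {"random_seed": 10}, "tool_choice": "any"}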
@@ -1327,8 +1434,8 @@ class OpenAIAssistantsAPI(BaseLLM):
             client=client,
         )

-        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
-            thread_id, **message_data
+        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
+            thread_id, **message_data  # type: ignore
         )

         response_obj: Optional[OpenAIMessage] = None
@@ -1458,7 +1565,7 @@ class OpenAIAssistantsAPI(BaseLLM):
             client=client,
         )

-        response = openai_client.beta.threads.runs.create_and_poll(
+        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
             thread_id=thread_id,
             assistant_id=assistant_id,
             additional_instructions=additional_instructions,

@@ -665,6 +665,7 @@ def test_completion_mistral_api():
                     "content": "Hey, how's it going?",
                 }
             ],
+            seed=10,
         )
         # Add any assertions here to check the response
         print(response)

@@ -37,14 +37,19 @@ def get_current_weather(location, unit="fahrenheit"):
 # Example dummy function hard coded to return the same weather
 # In production, this could be your backend API or an external API
-def test_parallel_function_call():
+@pytest.mark.parametrize(
+    "model", ["gpt-3.5-turbo-1106", "mistral/mistral-large-latest"]
+)
+def test_parallel_function_call(model):
     try:
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
                 "role": "user",
-                "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
+                "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
             }
         ]
         tools = [
@@ -58,7 +63,7 @@ def test_parallel_function_call():
                         "properties": {
                             "location": {
                                 "type": "string",
-                                "description": "The city and state, e.g. San Francisco, CA",
+                                "description": "The city and state",
                             },
                             "unit": {
                                 "type": "string",
@@ -71,7 +76,7 @@ def test_parallel_function_call():
             }
         ]
         response = litellm.completion(
-            model="gpt-3.5-turbo-1106",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="auto",  # auto is default, but we'll be explicit
@@ -83,8 +88,8 @@ def test_parallel_function_call():
         print("length of tool calls", len(tool_calls))
         print("Expecting there to be 3 tool calls")
         assert (
-            len(tool_calls) > 1
-        )  # this has to call the function for SF, Tokyo and parise
+            len(tool_calls) > 0
+        )  # this has to call the function for SF, Tokyo and paris
         # Step 2: check if the model wanted to call a function
         if tool_calls:
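
The assertion is loosened because Mistral is not guaranteed to emit three parallel tool calls for this prompt the way gpt-3.5-turbo-1106 does. Between this hunk and the next, the test dispatches whatever tool calls came back; a hedged sketch of that kind of loop (illustrative, not the diff's code; it reuses get_current_weather from the top of the test file):

import json

available_functions = {"get_current_weather": get_current_weather}
for tool_call in tool_calls:
    function_name = tool_call.function.name
    function_args = json.loads(tool_call.function.arguments)
    # Works whether the model emitted 1 or N tool calls.
    function_response = available_functions[function_name](
        location=function_args.get("location"),
        unit=function_args.get("unit", "fahrenheit"),
    )
    messages.append(
        {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": function_name,
            "content": function_response,
        }
    )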
@@ -116,7 +121,7 @@ def test_parallel_function_call():
             )  # extend conversation with function response
         print(f"messages: {messages}")
         second_response = litellm.completion(
-            model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22
+            model=model, messages=messages, temperature=0.2, seed=22
         )  # get a new response from the model where it can see the function response
         print("second response\n", second_response)
         return second_response

@@ -5617,32 +5617,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream is not None:
-            optional_params["stream"] = stream
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if tools is not None:
-            optional_params["tools"] = tools
-        if tool_choice is not None:
-            optional_params["tool_choice"] = tool_choice
-        if response_format is not None:
-            optional_params["response_format"] = response_format
-        # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion
-        safe_mode = passed_params.pop("safe_mode", None)
-        random_seed = passed_params.pop("random_seed", None)
-        extra_body = {}
-        if safe_mode is not None:
-            extra_body["safe_mode"] = safe_mode
-        if random_seed is not None:
-            extra_body["random_seed"] = random_seed
-        optional_params["extra_body"] = (
-            extra_body  # openai client supports `extra_body` param
-        )
+        optional_params = litellm.MistralConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "groq":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -5843,7 +5820,8 @@ def get_optional_params(
         for k in passed_params.keys():
             if k not in default_params.keys():
                 extra_body[k] = passed_params[k]
-        optional_params["extra_body"] = extra_body
+        optional_params.setdefault("extra_body", {})
+        optional_params["extra_body"] = {**optional_params["extra_body"], **extra_body}
     else:
         # if user passed in non-default kwargs for specific providers/models, pass them along
         for k in passed_params.keys():
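
Merging instead of overwriting preserves anything already stored under extra_body earlier in get_optional_params (for example a random_seed) while still adding the leftover provider-specific kwargs. A small illustration of the semantics with plain dicts (not litellm code; the key names are only examples):

# extra_body already populated earlier, e.g. by MistralConfig.map_openai_params for `seed`
optional_params = {"temperature": 0.2, "extra_body": {"random_seed": 10}}
extra_body = {"safe_prompt": True}  # leftover provider-specific kwargs collected above

optional_params.setdefault("extra_body", {})
optional_params["extra_body"] = {**optional_params["extra_body"], **extra_body}

print(optional_params["extra_body"])  # {'random_seed': 10, 'safe_prompt': True}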
@@ -6212,15 +6190,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "max_retries",
         ]
     elif custom_llm_provider == "mistral":
-        return [
-            "temperature",
-            "top_p",
-            "stream",
-            "max_tokens",
-            "tools",
-            "tool_choice",
-            "response_format",
-        ]
+        return litellm.MistralConfig().get_supported_openai_params()
     elif custom_llm_provider == "replicate":
         return [
             "stream",