fix(utils.py): support azure mistral function calling

This commit is contained in:
Krrish Dholakia 2024-04-17 19:10:12 -07:00
parent 9c7179e66f
commit 18e3cf8bff
3 changed files with 80 additions and 1 deletions

View file

@@ -3,7 +3,8 @@ model_list:
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
# api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_base: http://0.0.0.0:8080
stream_timeout: 0.001
rpm: 10
- litellm_params:

View file

@@ -33,6 +33,21 @@ def reset_callbacks():
litellm.callbacks = []
def test_response_model_none():
    """
    Addresses - https://github.com/BerriAI/litellm/issues/2972
    """
    # A minimal chat request against a local fake endpoint; the only
    # expectation is that the result parses into a ModelResponse.
    messages = [{"role": "user", "content": "Hello!"}]
    resp = completion(
        model="mymodel",
        custom_llm_provider="openai",
        messages=messages,
        api_base="http://0.0.0.0:8080",
        api_key="my-api-key",
    )
    print(f"x: {resp}")
    assert isinstance(resp, litellm.ModelResponse)
def test_completion_custom_provider_model_name():
try:
litellm.cache = None
@@ -399,6 +414,51 @@ def test_completion_claude_3_function_plus_image():
print(response)
def test_completion_azure_mistral_large_function_calling():
    """
    This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
    """
    litellm.set_verbose = True
    # Single weather-lookup tool in the OpenAI function-calling schema.
    weather_fn = {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
    tools = [{"type": "function", "function": weather_fn}]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ]
    response = completion(
        model="azure/mistral-large-latest",
        api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
        api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
        messages=messages,
        tools=tools,
        tool_choice="auto",
    )
    # Add any assertions, here to check response args
    print(response)
    # Both function fields must come back as plain strings after parsing.
    tool_fn = response.choices[0].message.tool_calls[0].function
    assert isinstance(tool_fn.name, str)
    assert isinstance(tool_fn.arguments, str)
def test_completion_mistral_api():
try:
litellm.set_verbose = True

View file

@@ -226,6 +226,24 @@ class Function(OpenAIObject):
arguments: str
name: Optional[str] = None
def __init__(
    self,
    arguments: Union[Dict, str],
    name: Optional[str] = None,
    **params,
):
    """Accept tool-call arguments as a dict or a pre-serialized JSON string.

    The model's `arguments` field is declared as `str`, so dict input is
    normalized with `json.dumps` before delegating to the base model.

    Args:
        arguments: Function-call arguments, either a dict (serialized to
            JSON here) or an already-serialized JSON string (passed through).
        name: Optional function name.
        **params: Any extra fields forwarded to the base model.
    """
    # Serialize dict arguments to a JSON string; strings pass through as-is.
    # (isinstance check uses the builtin `dict` — `typing.Dict` is for
    # annotations, not runtime checks.)
    if isinstance(arguments, dict):
        arguments = json.dumps(arguments)
    # Build a dictionary with the structure your BaseModel expects
    data = {"arguments": arguments, "name": name, **params}
    super(Function, self).__init__(**data)
class ChatCompletionDeltaToolCall(OpenAIObject):
id: Optional[str] = None