test_litellm_params_not_overwritten_by_function_calling

This commit is contained in:
mc-marcocheng 2023-10-24 22:04:45 +08:00
parent 8dad2eec83
commit 3c28ff6167

View file

@ -108,43 +108,46 @@ def test_function_calling():
### FUNCTION CALLING -> NORMAL COMPLETION ### FUNCTION CALLING -> NORMAL COMPLETION
def test_litellm_params_not_overwritten_by_function_calling(): def test_litellm_params_not_overwritten_by_function_calling():
model_list = [ try:
{ model_list = [
"model_name": "gpt-3.5-turbo-0613", {
"litellm_params": { "model_name": "gpt-3.5-turbo-0613",
"model": "gpt-3.5-turbo-0613", "litellm_params": {
"api_key": os.getenv("OPENAI_API_KEY"), "model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
}, },
"tpm": 100000, ]
"rpm": 10000,
},
]
messages = [ messages = [
{"role": "user", "content": "What is the weather like in Boston?"} {"role": "user", "content": "What is the weather like in Boston?"}
] ]
functions = [ functions = [
{ {
"name": "get_current_weather", "name": "get_current_weather",
"description": "Get the current weather in a given location", "description": "Get the current weather in a given location",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA" "description": "The city and state, e.g. San Francisco, CA"
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"enum": ["celsius", "fahrenheit"] "enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
} }
}, }
"required": ["location"] ]
}
}
]
router = Router(model_list=model_list) router = Router(model_list=model_list)
_ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) _ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages) response = router.completion(model="gpt-3.5-turbo-0613", messages=messages)
assert response.choices[0].finish_reason != "function_call" assert response.choices[0].finish_reason != "function_call"
except Exception as e:
pytest.fail(f"Error occurred: {e}")