diff --git a/litellm/router.py b/litellm/router.py index a1f9fcfc7..796cbb985 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -61,10 +61,8 @@ class Router: data = deployment["litellm_params"] data["messages"] = messages - for key, value in kwargs.items(): - data[key] = value # call via litellm.completion() - return litellm.completion(**data) + return litellm.completion(**{**data, **kwargs}) def text_completion(self, model: str, @@ -80,10 +78,8 @@ class Router: data = deployment["litellm_params"] data["prompt"] = prompt - for key, value in kwargs.items(): - data[key] = value # call via litellm.completion() - return litellm.text_completion(**data) + return litellm.text_completion(**{**data, **kwargs}) def embedding(self, model: str, @@ -96,7 +92,7 @@ class Router: data = deployment["litellm_params"] data["input"] = input # call via litellm.embedding() - return litellm.embedding(**data) + return litellm.embedding(**{**data, **kwargs}) def set_model_list(self, model_list: list): self.model_list = model_list diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 4206c6325..a2a311c5a 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -68,14 +68,14 @@ def test_function_calling(): litellm.set_verbose =True model_list = [ { - "model_name": "gpt-3.5-turbo-0613", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": "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", - }, - "tpm": 100000, - "rpm": 10000, + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), }, + "tpm": 100000, + "rpm": 10000, + }, ] messages = [ @@ -106,4 +106,48 @@ def test_function_calling(): response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) print(response) -test_function_calling() \ No newline at end of file +### FUNCTION CALLING -> NORMAL COMPLETION +def test_litellm_params_not_overwritten_by_function_calling(): + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + messages = [ + {"role": "user", "content": "What is the weather like in Boston?"} + ] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + ] + + router = Router(model_list=model_list) + _ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) + response = router.completion(model="gpt-3.5-turbo-0613", messages=messages) + assert response.choices[0].finish_reason != "function_call" + except Exception as e: + pytest.fail(f"Error occurred: {e}")