From 0590ec620b311f42936b01db70ca9ad11f01dc73 Mon Sep 17 00:00:00 2001 From: mc-marcocheng <57459045+mc-marcocheng@users.noreply.github.com> Date: Tue, 24 Oct 2023 16:02:15 +0800 Subject: [PATCH 1/3] Fix data being overwritten --- litellm/router.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index a1f9fcfc7..796cbb985 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -61,10 +61,8 @@ class Router: data = deployment["litellm_params"] data["messages"] = messages - for key, value in kwargs.items(): - data[key] = value # call via litellm.completion() - return litellm.completion(**data) + return litellm.completion(**{**data, **kwargs}) def text_completion(self, model: str, @@ -80,10 +78,8 @@ class Router: data = deployment["litellm_params"] data["prompt"] = prompt - for key, value in kwargs.items(): - data[key] = value # call via litellm.completion() - return litellm.text_completion(**data) + return litellm.text_completion(**{**data, **kwargs}) def embedding(self, model: str, @@ -96,7 +92,7 @@ class Router: data = deployment["litellm_params"] data["input"] = input # call via litellm.embedding() - return litellm.embedding(**data) + return litellm.embedding(**{**data, **kwargs}) def set_model_list(self, model_list: list): self.model_list = model_list From 8dad2eec83ddc5acdafa501a0da84f1c1b8c085a Mon Sep 17 00:00:00 2001 From: mc-marcocheng Date: Tue, 24 Oct 2023 22:03:04 +0800 Subject: [PATCH 2/3] test_litellm_params_not_overwritten_by_function_calling --- litellm/tests/test_router.py | 57 +++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 4206c6325..5f6892cd1 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -68,14 +68,14 @@ def test_function_calling(): litellm.set_verbose =True model_list = [ { - "model_name": "gpt-3.5-turbo-0613", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": "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", - }, - "tpm": 100000, - "rpm": 10000, + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), }, + "tpm": 100000, + "rpm": 10000, + }, ] messages = [ @@ -106,4 +106,45 @@ def test_function_calling(): response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) print(response) -test_function_calling() \ No newline at end of file +### FUNCTION CALLING -> NORMAL COMPLETION +def test_litellm_params_not_overwritten_by_function_calling(): + model_list = [ + { + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + messages = [ + {"role": "user", "content": "What is the weather like in Boston?"} + ] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + ] + + router = Router(model_list=model_list) + _ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) + response = router.completion(model="gpt-3.5-turbo-0613", messages=messages) + assert response.choices[0].finish_reason != "function_call" From 3c28ff616708a3ebb7ae285594724e5b00b715b7 Mon Sep 17 00:00:00 2001 From: mc-marcocheng Date: Tue, 24 Oct 2023 22:04:45 +0800 Subject: [PATCH 3/3] test_litellm_params_not_overwritten_by_function_calling --- litellm/tests/test_router.py | 75 +++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 5f6892cd1..a2a311c5a 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -108,43 +108,46 @@ def test_function_calling(): ### FUNCTION CALLING -> NORMAL COMPLETION def test_litellm_params_not_overwritten_by_function_calling(): - model_list = [ - { - "model_name": "gpt-3.5-turbo-0613", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, }, - "tpm": 100000, - "rpm": 10000, - }, - ] + ] - messages = [ - {"role": "user", "content": "What is the weather like in Boston?"} - ] - functions = [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] + messages = [ + {"role": "user", "content": "What is the weather like in Boston?"} + ] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] } - }, - "required": ["location"] - } - } - ] + } + ] - router = Router(model_list=model_list) - _ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) - response = router.completion(model="gpt-3.5-turbo-0613", messages=messages) - assert response.choices[0].finish_reason != "function_call" + router = Router(model_list=model_list) + _ = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) + response = router.completion(model="gpt-3.5-turbo-0613", messages=messages) + assert response.choices[0].finish_reason != "function_call" + except Exception as e: + pytest.fail(f"Error occurred: {e}")