Merge pull request #679 from mc-marcocheng/router-kwargs

Fix data being overwritten
This commit is contained in:
Krish Dholakia 2023-10-24 08:27:23 -07:00 committed by GitHub
commit d244978247
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 15 deletions

View file

@@ -61,10 +61,8 @@ class Router:
data = deployment["litellm_params"] data = deployment["litellm_params"]
data["messages"] = messages data["messages"] = messages
for key, value in kwargs.items():
data[key] = value
# call via litellm.completion() # call via litellm.completion()
return litellm.completion(**data) return litellm.completion(**{**data, **kwargs})
def text_completion(self, def text_completion(self,
model: str, model: str,
@@ -80,10 +78,8 @@ class Router:
data = deployment["litellm_params"] data = deployment["litellm_params"]
data["prompt"] = prompt data["prompt"] = prompt
for key, value in kwargs.items():
data[key] = value
# call via litellm.completion() # call via litellm.completion()
return litellm.text_completion(**data) return litellm.text_completion(**{**data, **kwargs})
def embedding(self, def embedding(self,
model: str, model: str,
@@ -96,7 +92,7 @@ class Router:
data = deployment["litellm_params"] data = deployment["litellm_params"]
data["input"] = input data["input"] = input
# call via litellm.embedding() # call via litellm.embedding()
return litellm.embedding(**data) return litellm.embedding(**{**data, **kwargs})
def set_model_list(self, model_list: list): def set_model_list(self, model_list: list):
self.model_list = model_list self.model_list = model_list

View file

@@ -71,7 +71,7 @@ def test_function_calling():
"model_name": "gpt-3.5-turbo-0613", "model_name": "gpt-3.5-turbo-0613",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo-0613", "model": "gpt-3.5-turbo-0613",
"api_key": "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", "api_key": os.getenv("OPENAI_API_KEY"),
}, },
"tpm": 100000, "tpm": 100000,
"rpm": 10000, "rpm": 10000,
@@ -106,4 +106,48 @@ def test_function_calling():
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
print(response) print(response)
test_function_calling() ### FUNCTION CALLING -> NORMAL COMPLETION
def test_litellm_params_not_overwritten_by_function_calling():
    """Regression test for kwargs leaking into litellm_params.

    The first completion call passes `functions`; if the router mutated the
    shared litellm_params dict, the second (plain) call would still trigger
    function calling. The second response must NOT finish with
    "function_call".
    """
    try:
        deployment = {
            "model_name": "gpt-3.5-turbo-0613",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "tpm": 100000,
            "rpm": 10000,
        }
        weather_function = {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location"]
            }
        }
        chat_messages = [
            {"role": "user", "content": "What is the weather like in Boston?"}
        ]
        router = Router(model_list=[deployment])
        # Call with function-calling kwargs first; discard the response.
        router.completion(
            model="gpt-3.5-turbo-0613",
            messages=chat_messages,
            functions=[weather_function],
        )
        # A plain follow-up completion must not inherit `functions`.
        response = router.completion(
            model="gpt-3.5-turbo-0613", messages=chat_messages
        )
        assert response.choices[0].finish_reason != "function_call"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")