diff --git a/litellm/router.py b/litellm/router.py
index dd6303a948..a58f01fc12 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -242,22 +242,85 @@ class Router:
     ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
 
     def completion(
-        self, model: str, messages: List[Dict[str, str]], **kwargs
+        self,
+        model: str,
+        # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+        messages: List = [],
+        functions: Optional[List] = None,
+        function_call: Optional[str] = None,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stream: Optional[bool] = None,
+        stop=None,
+        max_tokens: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[dict] = None,
+        user: Optional[str] = None,
+        # openai v1.0+ new params
+        response_format: Optional[dict] = None,
+        seed: Optional[int] = None,
+        tools: Optional[List] = None,
+        tool_choice: Optional[str] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        deployment_id=None,
+        # set api_base, api_version, api_key
+        base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
+        # Optional liteLLM function params
+        **kwargs,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         """
         Example usage:
-        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
+        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
         """
         try:
-            kwargs["model"] = model
-            kwargs["messages"] = messages
-            kwargs["original_function"] = self._completion
-            timeout = kwargs.get("request_timeout", self.timeout)
+            # Collect the explicit OpenAI params so they reach the underlying
+            # completion call alongside any extra **kwargs.
+            completion_kwargs = {
+                "model": model,
+                "messages": messages,
+                "functions": functions,
+                "function_call": function_call,
+                "timeout": timeout,
+                "temperature": temperature,
+                "top_p": top_p,
+                "n": n,
+                "stream": stream,
+                "stop": stop,
+                "max_tokens": max_tokens,
+                "presence_penalty": presence_penalty,
+                "frequency_penalty": frequency_penalty,
+                "logit_bias": logit_bias,
+                "user": user,
+                "response_format": response_format,
+                "seed": seed,
+                "tools": tools,
+                "tool_choice": tool_choice,
+                "logprobs": logprobs,
+                "top_logprobs": top_logprobs,
+                "deployment_id": deployment_id,
+                "base_url": base_url,
+                "api_version": api_version,
+                "api_key": api_key,
+                "model_list": model_list,
+                "original_function": self._completion,
+            }
+            # Preserve the previous fallback chain for the executor timeout:
+            # explicit timeout param, then request_timeout, then the router default.
+            timeout = timeout or kwargs.get("request_timeout", self.timeout)
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                 # Submit the function to the executor with a timeout
-                future = executor.submit(self.function_with_fallbacks, **kwargs)
+                future = executor.submit(
+                    self.function_with_fallbacks, **kwargs, **completion_kwargs
+                )
                 response = future.result(timeout=timeout)  # type: ignore
                 return response
 
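
Usage sketch (not part of the patch): a minimal example of calling the new
explicit-parameter signature. It assumes a Router constructed with litellm's
standard model_list format; the deployment entry below is illustrative only.

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # model group routed by completion()
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ]
    )

    # OpenAI params that previously had to travel through **kwargs can now be
    # passed by name; they are forwarded via completion_kwargs.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        temperature=0.2,
        max_tokens=256,
        timeout=30,
    )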