diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f8cb9e96e..43e9bc002 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -113,6 +113,46 @@ app.add_middleware(
     allow_headers=["*"],
 )
+
+from typing import Dict
+from pydantic import BaseModel, Extra
+######### Request Class Definition ######
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Dict[str, str]]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    stop: Optional[List[str]] = None
+    max_tokens: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    response_format: Optional[Dict[str, str]] = None
+    seed: Optional[int] = None
+    tools: Optional[List[str]] = None
+    tool_choice: Optional[str] = None
+    functions: List[str] = None  # soon to be deprecated
+    function_call: Optional[str] = None  # soon to be deprecated
+
+    # Optional LiteLLM params
+    caching: Optional[bool] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    api_key: Optional[str] = None
+    num_retries: Optional[int] = None
+    context_window_fallback_dict: Optional[Dict[str, str]] = None
+    fallbacks: Optional[List[str]] = None
+    metadata: Optional[Dict[str, str]] = {}
+    deployment_id: Optional[str] = None
+    request_timeout: Optional[int] = None
+
+    class Config:
+        extra = 'allow'  # allow params not defined here, these fall in litellm.completion(**kwargs)
+
+
 
 user_api_base = None
 user_model = None
 user_debug = False
@@ -707,21 +747,18 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
             status_code=status, detail=error_msg
         )
-    
+
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)) -> litellm.ModelResponse:
+async def chat_completion(request: ChatCompletionRequest, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)) -> litellm.ModelResponse:
     global general_settings, user_debug
     try:
         data = {}
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
+        request_items = request.model_dump()
+        data = {key: value for key, value in request_items.items() if value is not None}  # pydantic sets all values to None, filter out None values here
+
         print_verbose(f"receiving data: {data}")
         data["model"] = (
             general_settings.get("completion_model", None)  # server default
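
For reference, a minimal standalone sketch (not part of the diff) of the behaviour the new request model relies on: params not declared on the model are preserved via extra="allow", and the handler drops the None defaults before the dict is forwarded to litellm.completion(**data). This sketch assumes Pydantic v2 (model_dump() is a v2 method) and uses ConfigDict rather than the class Config form shown in the diff; the field subset and the mock_response extra param are illustrative only.

# Sketch only, assuming Pydantic v2; field list trimmed for brevity.
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict


class ChatCompletionRequest(BaseModel):
    # extra="allow" keeps params not declared below, so they still reach
    # litellm.completion(**data) unchanged.
    model_config = ConfigDict(extra="allow")

    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
    max_tokens: Optional[float] = None


req = ChatCompletionRequest(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="test",  # hypothetical extra param, kept because of extra="allow"
)

# Mirror the handler: dump the model, then filter out fields Pydantic defaulted to None.
data = {key: value for key, value in req.model_dump().items() if value is not None}
print(data)
# -> {'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'hi'}], 'mock_response': 'test'}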