diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index be268cfe4..1fe305e41 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ca1ee6f23..28bd22223 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -118,6 +118,66 @@ def data_generator(response):
         print_verbose(f"returned chunk: {chunk}")
         yield f"data: {json.dumps(chunk)}\n\n"
 
+def litellm_completion(data, type):
+    try:
+        if user_model:
+            data["model"] = user_model
+        # override with user settings
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        ## CUSTOM PROMPT TEMPLATE ## - run `litellm --config` to set this
+        litellm.register_prompt_template(
+            model=user_model,
+            roles={
+                "system": {
+                    "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
+                },
+                "assistant": {
+                    "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
+                },
+                "user": {
+                    "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
+                }
+            },
+            initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
+            final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
+        )
+        if type == "completion":
+            response = litellm.text_completion(**data)
+        elif type == "chat_completion":
+            response = litellm.completion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(data_generator(response), media_type='text/event-stream')
+        print_verbose(f"response: {response}")
+        return response
+    except Exception as e:
+        if "Invalid response object from API" in str(e):
+            completion_call_details = {}
+            if user_model:
+                completion_call_details["model"] = user_model
+            else:
+                completion_call_details["model"] = data['model']
+
+            if user_api_base:
+                completion_call_details["api_base"] = user_api_base
+            else:
+                completion_call_details["api_base"] = None
+            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
+            if completion_call_details["api_base"] == "http://localhost:11434":
+                print()
+                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
+                print()
+        else:
+            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
+        return {"message": "An error occurred"}, 500
+
 #### API ENDPOINTS ####
 @router.get("/models") # if project requires model list
 def model_list():
@@ -136,82 +196,15 @@ def model_list():
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    print_verbose(f"data passed in: {data}")
-    if user_model:
-        data["model"] = user_model
-    if user_api_base:
-        data["api_base"] = user_api_base
-    # override with user settings
-    if user_temperature:
-        data["temperature"] = user_temperature
-    if user_max_tokens:
-        data["max_tokens"] = user_max_tokens
-
-    ## check for custom prompt template ##
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
-    response = litellm.text_completion(**data)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
+    return litellm_completion(data=data, type="completion")
 
 @router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
-    if user_model:
-        data["model"] = user_model
-    # override with user settings
-    if user_temperature:
-        data["temperature"] = user_temperature
-    if user_max_tokens:
-        data["max_tokens"] = user_max_tokens
-    if user_api_base:
-        data["api_base"] = user_api_base
-    ## check for custom prompt template ##
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
-    response = litellm.completion(**data)
-
+    response = litellm_completion(data, type="chat_completion")
     # track cost of this response, using litellm.completion_cost
-    await track_cost(response)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    print_verbose(f"response: {response}")
+    await track_cost(response)
     return response
 
 async def track_cost(response):
diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py
index b563b02af..886c139f3 100644
--- a/litellm/tests/test_bad_params.py
+++ b/litellm/tests/test_bad_params.py
@@ -54,17 +54,19 @@ def test_completion_invalid_param_cohere():
         else:
             pytest.fail(f'An error occurred {e}')
 
-test_completion_invalid_param_cohere()
+# test_completion_invalid_param_cohere()
 
 def test_completion_function_call_cohere():
     try:
-        response = completion(model="command-nightly", messages=messages, function_call="TEST-FUNCTION")
+        response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
     except Exception as e:
         if "Function calling is not supported by this provider" in str(e):
             pass
         else:
             pytest.fail(f'An error occurred {e}')
 
+test_completion_function_call_cohere()
+
 def test_completion_function_call_openai():
     try:
         messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
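
For reviewers trying out the `## CUSTOM PROMPT TEMPLATE ##` block above: `litellm_completion` reads the template pieces from environment variables (the in-code comment points at `litellm --config` as the supported way to set them). Below is a minimal sketch of exporting those variables from Python before the proxy starts; the Llama-2-style token values are illustrative assumptions, not values this patch ships.

```python
import os

# Assumed, illustrative Llama-2 chat tokens -- substitute whatever your model expects.
# litellm_completion() reads each of these via os.getenv(..., "") at request time.
os.environ["MODEL_SYSTEM_MESSAGE_START_TOKEN"] = "<<SYS>>\n"
os.environ["MODEL_SYSTEM_MESSAGE_END_TOKEN"] = "\n<</SYS>>\n"
os.environ["MODEL_USER_MESSAGE_START_TOKEN"] = "[INST] "
os.environ["MODEL_USER_MESSAGE_END_TOKEN"] = " [/INST]"
os.environ["MODEL_ASSISTANT_MESSAGE_START_TOKEN"] = ""
os.environ["MODEL_ASSISTANT_MESSAGE_END_TOKEN"] = ""
os.environ["MODEL_PRE_PROMPT"] = ""   # maps to initial_prompt_value
os.environ["MODEL_POST_PROMPT"] = ""  # maps to final_prompt_value
```

Unset variables fall back to empty strings, so the registered template is effectively a no-op when nothing is configured.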
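To sanity-check that both routes now go through `litellm_completion`, a client sketch like the one below can be used. The base URL and port are assumptions about how the proxy was started locally, and the payload fields are only the ones this diff actually reads (`model`, `messages`/`prompt`, `stream`).

```python
import requests

BASE_URL = "http://0.0.0.0:8000"  # assumed local proxy address; adjust to your setup

# /chat/completions -> litellm_completion(data, type="chat_completion")
chat = requests.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # overridden if the proxy was started with --model
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
    },
)
print(chat.json())

# /completions with stream=True -> StreamingResponse(data_generator(...), "text/event-stream")
with requests.post(
    f"{BASE_URL}/completions",
    json={"model": "gpt-3.5-turbo", "prompt": "Say this is a test", "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())
```

Note that on an upstream failure both routes return whatever `litellm_completion`'s except branch produces, so a client should check the response body as well as the status code.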