From c0204310eea8196547eb521cf377f937865dd125 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 30 Mar 2024 09:02:51 -0700
Subject: [PATCH] fix(main.py): fix translation to text_completions format for
 async text completion calls

---
 litellm/main.py                | 23 ++++++++++++++++++++++-
 proxy_server_config.yaml       |  4 ++--
 tests/test_openai_endpoints.py | 23 +++++++++++++++++------
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 665784f1d..1fcf0d5d3 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2952,7 +2952,26 @@ async def atext_completion(*args, **kwargs):
                 model=model,
             )
         else:
-            return response
+            transformed_logprobs = None
+            # only supported for TGI models
+            try:
+                raw_response = response._hidden_params.get("original_response", None)
+                transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
+            except Exception as e:
+                print_verbose(f"LiteLLM non blocking exception: {e}")
+            text_completion_response = TextCompletionResponse()
+            text_completion_response["id"] = response.get("id", None)
+            text_completion_response["object"] = "text_completion"
+            text_completion_response["created"] = response.get("created", None)
+            text_completion_response["model"] = response.get("model", None)
+            text_choices = TextChoices()
+            text_choices["text"] = response["choices"][0]["message"]["content"]
+            text_choices["index"] = response["choices"][0]["index"]
+            text_choices["logprobs"] = transformed_logprobs
+            text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
+            text_completion_response["choices"] = [text_choices]
+            text_completion_response["usage"] = response.get("usage", None)
+            return text_completion_response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
@@ -3165,6 +3184,7 @@ def text_completion(
         transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
     except Exception as e:
         print_verbose(f"LiteLLM non blocking exception: {e}")
+
     text_completion_response["id"] = response.get("id", None)
     text_completion_response["object"] = "text_completion"
     text_completion_response["created"] = response.get("created", None)
@@ -3176,6 +3196,7 @@ def text_completion(
     text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
     text_completion_response["choices"] = [text_choices]
     text_completion_response["usage"] = response.get("usage", None)
+
     return text_completion_response
 
 
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 30033b28b..089c1e95c 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -45,8 +45,8 @@ model_list:
 
 litellm_settings:
   drop_params: True
-  max_budget: 100
-  budget_duration: 30d
+  # max_budget: 100
+  # budget_duration: 30d
   num_retries: 5
   request_timeout: 600
   telemetry: False
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 432f2f1ab..9535b4842 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -2,7 +2,8 @@
 ## Tests /chat/completions by generating a key and then making a chat completions request
 import pytest
 import asyncio
-import aiohttp
+import aiohttp, openai
+from openai import OpenAI
 
 
 async def generate_key(session):
@@ -114,14 +115,14 @@ async def completion(session, key):
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
-        response_text = await response.text()
-
-        print(response_text)
-        print()
 
         if status != 200:
             raise Exception(f"Request did not return a 200 status code: {status}")
 
+        response = await response.json()
+
+        return response
+
 
 @pytest.mark.asyncio
 async def test_completion():
@@ -137,7 +138,17 @@ async def test_completion():
         await completion(session=session, key=key)
         key_gen = await new_user(session=session)
         key_2 = key_gen["key"]
-        await completion(session=session, key=key_2)
+        # response = await completion(session=session, key=key_2)
+
+        ## validate openai format ##
+        client = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")
+
+        client.completions.create(
+            model="gpt-4",
+            prompt="Say this is a test",
+            max_tokens=7,
+            temperature=0,
+        )
 
 
 async def embeddings(session, key):
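Note (not part of the commit): the litellm/main.py hunk makes the non-streaming
atext_completion() path return the same OpenAI text_completion shape that the
synchronous text_completion() already builds, mapping the chat-style
choices[0]["message"]["content"] onto choices[0]["text"]. A minimal sketch of the
fixed behavior follows, assuming a valid OPENAI_API_KEY in the environment; the
model and prompt mirror the new test, but the script itself is illustrative and
not taken from the patch.

# Sketch only (assumed usage, not from the patch): shows the response shape
# litellm.atext_completion() returns after this fix. Requires OPENAI_API_KEY.
import asyncio

import litellm


async def main():
    # gpt-4 is a chat model, so this call exercises the chat -> text_completion
    # translation block added to atext_completion() in this commit
    response = await litellm.atext_completion(
        model="gpt-4",
        prompt="Say this is a test",
        max_tokens=7,
        temperature=0,
    )
    assert response["object"] == "text_completion"
    print(response["choices"][0]["text"])  # completion text, not a chat message
    print(response["choices"][0]["finish_reason"])


asyncio.run(main())

Before this change, the non-streaming async path returned the raw chat-style
response, so choices[0]["text"] was missing; the new OpenAI-client call in
test_completion checks the same contract end to end through the proxy.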