fix(main.py): fix translation to text_completions format for async text completion calls

This commit is contained in:
Krrish Dholakia 2024-03-30 09:02:51 -07:00
parent 89471ba4c5
commit c0204310ee
3 changed files with 41 additions and 9 deletions

View file

@@ -2952,7 +2952,26 @@ async def atext_completion(*args, **kwargs):
                 model=model,
             )
         else:
-            return response
+            transformed_logprobs = None
+            # only supported for TGI models
+            try:
+                raw_response = response._hidden_params.get("original_response", None)
+                transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
+            except Exception as e:
+                print_verbose(f"LiteLLM non blocking exception: {e}")
+            text_completion_response = TextCompletionResponse()
+            text_completion_response["id"] = response.get("id", None)
+            text_completion_response["object"] = "text_completion"
+            text_completion_response["created"] = response.get("created", None)
+            text_completion_response["model"] = response.get("model", None)
+            text_choices = TextChoices()
+            text_choices["text"] = response["choices"][0]["message"]["content"]
+            text_choices["index"] = response["choices"][0]["index"]
+            text_choices["logprobs"] = transformed_logprobs
+            text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
+            text_completion_response["choices"] = [text_choices]
+            text_completion_response["usage"] = response.get("usage", None)
+            return text_completion_response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
@@ -3165,6 +3184,7 @@ def text_completion(
             transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
         except Exception as e:
             print_verbose(f"LiteLLM non blocking exception: {e}")
         text_completion_response["id"] = response.get("id", None)
         text_completion_response["object"] = "text_completion"
         text_completion_response["created"] = response.get("created", None)
@@ -3176,6 +3196,7 @@ def text_completion(
         text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
         text_completion_response["choices"] = [text_choices]
         text_completion_response["usage"] = response.get("usage", None)
         return text_completion_response

View file

@@ -45,8 +45,8 @@ model_list:
 litellm_settings:
   drop_params: True
-  max_budget: 100
-  budget_duration: 30d
+  # max_budget: 100
+  # budget_duration: 30d
   num_retries: 5
   request_timeout: 600
 telemetry: False

View file

@@ -2,7 +2,8 @@
 ## Tests /chat/completions by generating a key and then making a chat completions request
 import pytest
 import asyncio
-import aiohttp
+import aiohttp, openai
+from openai import OpenAI


 async def generate_key(session):
@@ -114,14 +115,14 @@ async def completion(session, key):
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
+        response_text = await response.text()
+        print(response_text)
+        print()

         if status != 200:
             raise Exception(f"Request did not return a 200 status code: {status}")

+        response = await response.json()
+        return response


 @pytest.mark.asyncio
 async def test_completion():
@@ -137,7 +138,17 @@ async def test_completion():
         await completion(session=session, key=key)
         key_gen = await new_user(session=session)
         key_2 = key_gen["key"]
-        await completion(session=session, key=key_2)
+        # response = await completion(session=session, key=key_2)
+        ## validate openai format ##
+        client = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")
+
+        client.completions.create(
+            model="gpt-4",
+            prompt="Say this is a test",
+            max_tokens=7,
+            temperature=0,
+        )


 async def embeddings(session, key):