Merge pull request #144 from yujonglee/fix-modelresponse-typing

Fix ModelResponse typing
Krish Dholakia 2023-08-18 04:46:42 -07:00 committed by GitHub
commit 3d1934a829
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 11 deletions
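In short, the change replaces the loosely typed model_response: dict parameter on the provider completion methods with the ModelResponse class from litellm.utils, and annotates the return type of the top-level completion() to match. A minimal sketch of the resulting signatures, condensed from the diffs below (method bodies are elided or assumed for illustration and are not part of this commit):

from typing import Callable

from litellm.utils import ModelResponse  # OpenAI-style response container


class AnthropicLLM:
    def completion(self, model: str, messages: list, model_response: ModelResponse,
                   print_verbose: Callable, optional_params=None,
                   litellm_params=None, logger_fn=None):
        # provider-specific request/response handling goes here (elided)
        return model_response  # assumed: the provider method fills and returns the response


def completion(model, messages, **kwargs) -> ModelResponse:
    # the public entrypoint now advertises a ModelResponse return value
    model_response = ModelResponse()
    ...
    return model_response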

View file

@@ -4,6 +4,7 @@ import requests
 from litellm import logging
 import time
 from typing import Callable
+from litellm.utils import ModelResponse
 
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman:"
@@ -36,7 +37,7 @@ class AnthropicLLM:
             "x-api-key": self.api_key
         }
 
-    def completion(self, model: str, messages: list, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
+    def completion(self, model: str, messages: list, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
         model = model
         prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
         for message in messages:

View file

@@ -5,6 +5,7 @@ import requests
 from litellm import logging
 import time
 from typing import Callable
+from litellm.utils import ModelResponse
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
@@ -26,7 +27,7 @@ class HuggingfaceRestAPILLM():
         if self.api_key != None:
             self.headers["Authorization"] = f"Bearer {self.api_key}"
 
-    def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
+    def completion(self, model: str, messages: list, custom_api_base: str, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
         if custom_api_base:
             completion_url = custom_api_base
         elif "HF_API_BASE" in os.environ:

View file

@@ -40,7 +40,7 @@ def completion(
     # model specific optional params
     # used by text-bison only
     top_k=40, request_timeout=0, # unused var for old version of OpenAI API
-):
+) -> ModelResponse:
   try:
     model_response = ModelResponse()
     if azure: # this flag is deprecated, remove once notebooks are also updated.

View file

@@ -94,14 +94,9 @@ def test_completion_openai():
         response_str = response['choices'][0]['message']['content']
         response_str_2 = response.choices[0].message.content
-        print(response_str)
-        print(response_str_2)
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
-        if type(response_str_2) != str:
-            pytest.fail(f"Error occurred: {e}")
-        # Add any assertions here to check the response
-        print(response)
+        assert response_str == response_str_2
+        assert type(response_str) == str
+        assert len(response_str) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")