diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 00f8bb3eb..67666ee92 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -4,6 +4,7 @@ import requests from litellm import logging import time from typing import Callable +from litellm.utils import ModelResponse class AnthropicConstants(Enum): HUMAN_PROMPT = "\n\nHuman:" @@ -36,7 +37,7 @@ class AnthropicLLM: "x-api-key": self.api_key } - def completion(self, model: str, messages: list, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls + def completion(self, model: str, messages: list, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls model = model prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" for message in messages: diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index d3beb6451..30d67727f 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -5,6 +5,7 @@ import requests from litellm import logging import time from typing import Callable +from litellm.utils import ModelResponse class HuggingfaceError(Exception): def __init__(self, status_code, message): @@ -26,7 +27,7 @@ class HuggingfaceRestAPILLM(): if self.api_key != None: self.headers["Authorization"] = f"Bearer {self.api_key}" - def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls + def completion(self, model: str, messages: list, custom_api_base: str, model_response: ModelResponse, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls if custom_api_base: completion_url = custom_api_base elif "HF_API_BASE" in os.environ: diff --git a/litellm/main.py b/litellm/main.py index 5749b8e2c..5361f0337 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -40,7 +40,7 @@ def completion( # model specific optional params # used by text-bison only top_k=40, request_timeout=0, # unused var for old version of OpenAI API - ): + ) -> ModelResponse: try: model_response = ModelResponse() if azure: # this flag is deprecated, remove once notebooks are also updated.