## Uses the huggingface text generation inference API
import os, json
import requests
import time
from typing import Callable
from litellm import logging


class HuggingfaceError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class HuggingfaceRestAPILLM:
    def __init__(self, encoding, api_key=None) -> None:
        self.encoding = encoding
        self.validate_environment(api_key=api_key)

    def validate_environment(self, api_key):
        # set up the environment required to run the model
        self.headers = {
            "content-type": "application/json",
        }
        # get the api key if it exists in the environment or is passed in, but don't require it
        self.api_key = os.getenv("HF_TOKEN") if "HF_TOKEN" in os.environ else api_key
        if self.api_key is not None:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    def completion(
        self,
        model: str,
        messages: list,
        custom_api_base: str,
        model_response: dict,
        print_verbose: Callable,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
    ):  # logic for parsing in - calling - parsing out model completion calls
        optional_params = optional_params or {}
        # resolve the completion endpoint: explicit arg > env var > public Inference API
        if custom_api_base:
            completion_url = custom_api_base
        elif "HF_API_BASE" in os.environ:
            completion_url = os.getenv("HF_API_BASE")
        else:
            completion_url = f"https://api-inference.huggingface.co/models/{model}"
        prompt = ""
        if "meta-llama" in model and "chat" in model:
            # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            prompt = "<s>"
            for message in messages:
                if message["role"] == "system":
                    prompt += "[INST] <<SYS>>" + message["content"]
                elif message["role"] == "assistant":
                    prompt += message["content"] + "</s><s>[INST]"
                elif message["role"] == "user":
                    prompt += message["content"] + "[/INST]"
        else:
            for message in messages:
                prompt += f"{message['content']}"
        ### MAP INPUT PARAMS
        # max tokens - the HF text generation API expects "max_new_tokens"
        if "max_tokens" in optional_params:
            value = optional_params.pop("max_tokens")
            optional_params["max_new_tokens"] = value
        data = {
            "inputs": prompt,
            # "parameters": optional_params
        }
        ## LOGGING
        logging(
            model=model,
            input=prompt,
            additional_args={
                "litellm_params": litellm_params,
                "optional_params": optional_params,
            },
            logger_fn=logger_fn,
        )
        ## COMPLETION CALL
        response = requests.post(
            completion_url, headers=self.headers, data=json.dumps(data)
        )
        if "stream" in optional_params and optional_params["stream"] is True:
            return response.iter_lines()
        else:
            ## LOGGING
            logging(
                model=model,
                input=prompt,
                additional_args={
                    "litellm_params": litellm_params,
                    "optional_params": optional_params,
                    "original_response": response.text,
                },
                logger_fn=logger_fn,
            )
            print_verbose(f"raw model_response: {response.text}")
            ## RESPONSE OBJECT
            completion_response = response.json()
            print_verbose(f"response: {completion_response}")
            if isinstance(completion_response, dict) and "error" in completion_response:
                print_verbose(f"completion error: {completion_response['error']}")
                print_verbose(f"response.status_code: {response.status_code}")
                raise HuggingfaceError(
                    message=completion_response["error"],
                    status_code=response.status_code,
                )
            else:
                model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
            ## CALCULATING USAGE
            prompt_tokens = len(
                self.encoding.encode(prompt)
            )  ##[TODO] use the llama2 tokenizer here
            completion_tokens = len(
                self.encoding.encode(model_response["choices"][0]["message"]["content"])
            )  ##[TODO] use the llama2 tokenizer here
            model_response["created"] = time.time()
            model_response["model"] = model
            model_response["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
            return model_response

    def embedding(self):
        # logic for parsing in - calling - parsing out model embedding calls
        pass
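

# A minimal usage sketch (not part of the original module) showing how the class above
# might be driven directly. The tiktoken encoding, the hard-coded model id, and the
# OpenAI-style model_response skeleton are illustrative assumptions; in litellm these
# are normally supplied by the main completion() entrypoint.
if __name__ == "__main__":
    import tiktoken  # assumed stand-in tokenizer for the prompt/completion token counts

    encoding = tiktoken.get_encoding("cl100k_base")
    # OpenAI-style response skeleton that completion() fills in place
    model_response = {
        "choices": [{"message": {"role": "assistant", "content": ""}}],
        "created": None,
        "model": None,
        "usage": {},
    }
    hf_llm = HuggingfaceRestAPILLM(encoding=encoding, api_key=os.getenv("HF_TOKEN"))
    result = hf_llm.completion(
        model="meta-llama/Llama-2-7b-chat-hf",  # illustrative model id
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        custom_api_base=None,  # falls through to the public Inference API URL
        model_response=model_response,
        print_verbose=print,
        optional_params={"max_tokens": 100},  # mapped to max_new_tokens internally
        litellm_params={},
        logger_fn=None,
    )
    print(result["choices"][0]["message"]["content"])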