refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions
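The changes below are the output of Black, Python's opinionated auto-formatter, applied across the repository (most likely via its CLI, for example "black ." from the project root). For reference, Black can also be invoked programmatically; the short sketch below is illustrative only, assumes the black package is installed, and uses a made-up source string rather than code from this commit:

# Illustrative sketch: reformat a source string with Black's Python API.
# The commit itself was presumably produced with the CLI, not this API.
import black

source = "def __init__(self, status_code, message, request=None, response=None): pass"
formatted = black.format_str(source, mode=black.FileMode())
print(formatted)  # prints the reformatted version of `source`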


@@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
else:
self.request = httpx.Request(
method="POST", url="https://api-inference.huggingface.co/models"
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class HuggingfaceConfig():
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = False # by default don't return the input as part of the output
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
@@ -46,50 +61,66 @@ class HuggingfaceConfig():
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None
) -> None:
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def output_parser(generated_text: str):
def output_parser(generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
tgi_models_cache = None
conv_models_cache = None
def read_tgi_conv_models():
try:
global tgi_models_cache, conv_models_cache
@@ -101,30 +132,38 @@ def read_tgi_conv_models():
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the file path relative to the script's directory
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
conv_models_cache = conv_models
return tgi_models, conv_models
except:
return set(), set()
def get_hf_task_for_model(model):
# read text file, cast it to set
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models:
@@ -134,9 +173,10 @@ def get_hf_task_for_model(model):
elif "roneneldan/TinyStories" in model:
return None
else:
return "text-generation-inference" # default to tgi
return "text-generation-inference" # default to tgi
class Huggingface(BaseLLM):
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
@@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
"content-type": "application/json",
}
if api_key and headers is None:
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = default_headers
elif headers:
headers=headers
else:
headers = headers
else:
headers = default_headers
return headers
def convert_to_model_response_object(self,
completion_response,
model_response,
task,
optional_params,
encoding,
input_text,
model):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
def convert_to_model_response_object(
self,
completion_response,
model_response,
task,
optional_params,
encoding,
input_text,
model,
):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"] # type: ignore
elif task == "text-generation-inference":
if (not isinstance(completion_response, list)
] = completion_response[
"generated_text"
] # type: ignore
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]):
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
else:
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
@@ -221,12 +289,14 @@ class Huggingface(BaseLLM):
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
else:
completion_tokens = 0
model_response["created"] = int(time.time())
@@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
model_response._hidden_params["original_response"] = completion_response
return model_response
def completion(self,
def completion(
self,
model: str,
messages: list,
api_base: Optional[str],
@@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
completion_url = f"https://api-inference.huggingface.co/models/{model}"
## Load Config
config=litellm.HuggingfaceConfig.get_config()
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
@@ -298,11 +371,11 @@ class Huggingface(BaseLLM):
generated_responses.append(message["content"])
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params
"parameters": inference_params,
}
input_text = "".join(message["content"] for message in messages)
elif task == "text-generation-inference":
@@ -311,29 +384,39 @@ class Huggingface(BaseLLM):
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
@@ -346,52 +429,68 @@ class Huggingface(BaseLLM):
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
)
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"task": task,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
else:
### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"]
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data)
completion_url, headers=headers, data=json.dumps(data)
)
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
is_streamed = False
if (
response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
completion_response: List[Dict[str, Any]] = [
{"generated_text": content}
]
## LOGGING
logging_obj.post_call(
input=input_text,
@@ -399,7 +498,7 @@ class Huggingface(BaseLLM):
original_response=completion_response,
additional_args={"complete_input_dict": data, "task": task},
)
else:
else:
## LOGGING
logging_obj.post_call(
input=input_text,
@@ -410,15 +509,20 @@ class Huggingface(BaseLLM):
## RESPONSE OBJECT
try:
completion_response = response.json()
if isinstance(completion_response, dict):
if isinstance(completion_response, dict):
completion_response = [completion_response]
except:
import traceback
raise HuggingfaceError(
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
status_code=response.status_code,
)
print_verbose(f"response: {completion_response}")
if isinstance(completion_response, dict) and "error" in completion_response:
if (
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
@@ -432,75 +536,98 @@ class Huggingface(BaseLLM):
optional_params=optional_params,
encoding=encoding,
input_text=input_text,
model=model
model=model,
)
except HuggingfaceError as e:
except HuggingfaceError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
except Exception as e:
if exception_mapping_worked:
raise e
else:
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
encoding: Any,
input_text: str,
model: str,
optional_params: dict):
response = None
try:
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
encoding: Any,
input_text: str,
model: str,
optional_params: dict,
):
response = None
try:
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers, timeout=None)
response = await client.post(
url=api_base, json=data, headers=headers, timeout=None
)
response_json = response.json()
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
raise HuggingfaceError(
status_code=response.status_code,
message=response.text,
request=response.request,
response=response,
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params)
except Exception as e:
if isinstance(e,httpx.TimeoutException):
return self.convert_to_model_response_object(
completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params,
)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
):
async with httpx.AsyncClient() as client:
response = client.stream(
"POST",
url=f"{api_base}",
json=data,
headers=headers
)
async with response as r:
"POST", url=f"{api_base}", json=data, headers=headers
)
async with response as r:
if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
raise HuggingfaceError(
status_code=r.status_code,
message="An error occurred while streaming",
)
streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding(self,
def embedding(
self,
model: str,
input: list,
api_key: Optional[str] = None,
@@ -523,65 +650,70 @@ class Huggingface(BaseLLM):
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {
"inputs": {
"source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
"source_sentence": input[0],
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
}
}
else:
data = {
"inputs": input # type: ignore
}
data = {"inputs": input} # type: ignore
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error'])
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
output_data = []
if "similarities" in embeddings:
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
}
)
else:
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
@@ -589,15 +721,17 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
else:
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][0] # flatten list returned from hf
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response["object"] = "list"
@@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = {
"prompt_tokens": input_tokens,
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response