fix: move to using pydantic obj for setting values

Krrish Dholakia 2024-07-11 13:18:36 -07:00
parent dd1048cb35
commit 6e9f048618
30 changed files with 1018 additions and 886 deletions

@@ -1,17 +1,22 @@
 ## Uses the huggingface text generation inference API
-import os, copy, types
-import json
-from enum import Enum
-import httpx, requests
-from .base import BaseLLM
-import time
-import litellm
-from typing import Callable, Dict, List, Any, Literal, Tuple
-from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
-from typing import Optional
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.types.completion import ChatCompletionMessageToolCallParam
+import copy
+import enum
+import json
+import os
+import time
+import types
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+
+import httpx
+import requests
+
+import litellm
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory

 class HuggingfaceError(Exception):
@@ -269,7 +274,7 @@ class Huggingface(BaseLLM):
     def convert_to_model_response_object(
         self,
         completion_response,
-        model_response,
+        model_response: litellm.ModelResponse,
         task: hf_tasks,
         optional_params,
         encoding,
@@ -278,11 +283,9 @@ class Huggingface(BaseLLM):
     ):
         if task == "conversational":
             if len(completion_response["generated_text"]) > 0:  # type: ignore
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response[
+                model_response.choices[0].message.content = completion_response[  # type: ignore
                     "generated_text"
-                ]  # type: ignore
+                ]
         elif task == "text-generation-inference":
             if (
                 not isinstance(completion_response, list)
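
This hunk shows the commit's core pattern: with model_response typed as a pydantic object, nested values are written through attributes instead of chained __getitem__ calls. A minimal runnable sketch of the difference, using simplified stand-in models rather than litellm's actual class definitions:

from typing import List, Optional

from pydantic import BaseModel


class Message(BaseModel):
    role: str = "assistant"
    content: Optional[str] = None


class Choice(BaseModel):
    message: Message = Message()


class ModelResponse(BaseModel):
    choices: List[Choice] = [Choice()]


resp = ModelResponse()

# Old style: string-keyed writes only work when the model implements
# __setitem__; a plain pydantic model raises TypeError on this:
#   resp["choices"][0]["message"]["content"] = "hello"

# New style: native attribute access, validated and type-checkable.
resp.choices[0].message.content = "hello"

Attribute writes go through pydantic's validation and are visible to static type checkers, which is what lets the remaining hunks drop the string-keyed indexing.
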
@@ -295,7 +298,7 @@ class Huggingface(BaseLLM):
                 )
             if len(completion_response[0]["generated_text"]) > 0:
-                model_response["choices"][0]["message"]["content"] = output_parser(
+                model_response.choices[0].message.content = output_parser(  # type: ignore
                     completion_response[0]["generated_text"]
                 )
             ## GETTING LOGPROBS + FINISH REASON
@@ -310,7 +313,7 @@ class Huggingface(BaseLLM):
                 for token in completion_response[0]["details"]["tokens"]:
                     if token["logprob"] != None:
                         sum_logprob += token["logprob"]
-                model_response["choices"][0]["message"]._logprob = sum_logprob
+                setattr(model_response.choices[0].message, "_logprob", sum_logprob)  # type: ignore
             if "best_of" in optional_params and optional_params["best_of"] > 1:
                 if (
                     "details" in completion_response[0]
@@ -337,14 +340,14 @@ class Huggingface(BaseLLM):
                             message=message_obj,
                         )
                         choices_list.append(choice_obj)
-                    model_response["choices"].extend(choices_list)
+                    model_response.choices.extend(choices_list)
         elif task == "text-classification":
-            model_response["choices"][0]["message"]["content"] = json.dumps(
+            model_response.choices[0].message.content = json.dumps(  # type: ignore
                 completion_response
             )
         else:
             if len(completion_response[0]["generated_text"]) > 0:
-                model_response["choices"][0]["message"]["content"] = output_parser(
+                model_response.choices[0].message.content = output_parser(  # type: ignore
                     completion_response[0]["generated_text"]
                 )
         ## CALCULATING USAGE
@@ -371,14 +374,14 @@ class Huggingface(BaseLLM):
         else:
             completion_tokens = 0

-        model_response["created"] = int(time.time())
-        model_response["model"] = model
+        model_response.created = int(time.time())
+        model_response.model = model
         usage = Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)
         model_response._hidden_params["original_response"] = completion_response
         return model_response
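
Moving from key assignment to a typed Usage object costs nothing for callers that still want plain dicts, since pydantic models serialize on demand. A sketch with a stand-in Usage model (the real one is imported from litellm.utils at the top of the file):

from pydantic import BaseModel


class Usage(BaseModel):  # stand-in; not litellm's actual definition
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15)

assert usage.total_tokens == 15  # attribute read replaces usage["total_tokens"]
print(usage.dict())              # or usage.model_dump() on pydantic v2
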
@@ -763,10 +766,10 @@ class Huggingface(BaseLLM):
         self,
         model: str,
         input: list,
+        model_response: litellm.EmbeddingResponse,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         logging_obj=None,
-        model_response=None,
         encoding=None,
     ):
         super().embedding()
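
Here model_response also becomes a required, typed parameter, which is presumably why it moves ahead of the defaulted arguments: Python rejects a non-default parameter after a defaulted one. A stand-in sketch of the signature constraint:

from typing import Optional


class EmbeddingResponse:  # stand-in for litellm.EmbeddingResponse
    pass


# Before: defaulted, so it could sit after the other keyword-style params.
def embedding_old(model: str, input: list, api_key: Optional[str] = None,
                  model_response=None): ...


# After: required and typed, so it must precede any defaulted params;
# placing it after api_key without a default would be a SyntaxError.
def embedding_new(model: str, input: list, model_response: EmbeddingResponse,
                  api_key: Optional[str] = None): ...

Making the parameter required also means the function can rely on a real EmbeddingResponse being present instead of guarding against None.
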
@@ -867,15 +870,21 @@ class Huggingface(BaseLLM):
                 ],  # flatten list returned from hf
             }
         )
-        model_response["object"] = "list"
-        model_response["data"] = output_data
-        model_response["model"] = model
+        model_response.object = "list"
+        model_response.data = output_data
+        model_response.model = model
         input_tokens = 0
         for text in input:
             input_tokens += len(encoding.encode(text))
-        model_response["usage"] = {
-            "prompt_tokens": input_tokens,
-            "total_tokens": input_tokens,
-        }
+        setattr(
+            model_response,
+            "usage",
+            litellm.Usage(
+                **{
+                    "prompt_tokens": input_tokens,
+                    "total_tokens": input_tokens,
+                }
+            ),
+        )
         return model_response
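
The final hunk replaces the raw usage dict with a litellm.Usage built via ** unpacking, which is equivalent to passing the keyword arguments directly, so the old dict payload is reused verbatim while gaining validation. A sketch of the equivalence with a stand-in model:

from pydantic import BaseModel


class Usage(BaseModel):  # stand-in; the real model is litellm.Usage
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


input_tokens = 42
legacy_payload = {"prompt_tokens": input_tokens, "total_tokens": input_tokens}

# ** unpacking feeds the existing dict straight into the model's
# keyword arguments; the two constructions produce equal objects.
assert Usage(**legacy_payload) == Usage(
    prompt_tokens=input_tokens, total_tokens=input_tokens
)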