fix: move to using pydantic obj for setting values

This commit is contained in:
Krrish Dholakia 2024-07-11 13:18:36 -07:00
parent dd1048cb35
commit 6e9f048618
30 changed files with 1018 additions and 886 deletions

View file

@ -1,16 +1,21 @@
import os, types, traceback
from enum import Enum
import json
import requests # type: ignore
import time
from typing import Callable, Optional, Any
import litellm
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys
from copy import deepcopy
import httpx # type: ignore
import io
from .prompt_templates.factory import prompt_factory, custom_prompt
import json
import os
import sys
import time
import traceback
import types
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Optional
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, get_secret
from .prompt_templates.factory import custom_prompt, prompt_factory
class SagemakerError(Exception):
@ -377,7 +382,7 @@ def completion(
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
model_response.choices[0].message.content = completion_output # type: ignore
except:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
@ -390,8 +395,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@ -597,7 +602,7 @@ async def async_completion(
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
completion_output = completion_output.replace(data["inputs"], "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
model_response.choices[0].message.content = completion_output # type: ignore
except:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
@ -610,8 +615,8 @@ async def async_completion(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@ -741,16 +746,20 @@ def embedding(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
),
)
return model_response