fix: move to using pydantic objects for setting values

Krrish Dholakia 2024-07-11 13:18:36 -07:00
parent dd1048cb35
commit 6e9f048618
30 changed files with 1018 additions and 886 deletions
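
The change is mechanical across the provider handlers that follow: dictionary-style item assignment on the response object is replaced with attribute access, and usage is attached as a litellm.Usage object via setattr. A minimal sketch of the before/after pattern, assuming a litellm install where ModelResponse and Usage are pydantic models and ModelResponse() builds a default response; the model name and token counts are illustrative:

import time

import litellm

model_response = litellm.ModelResponse()

# Old style: dict-style item assignment on the response object.
# model_response["created"] = int(time.time())
# model_response["model"] = "my-model"

# New style: plain attribute access on the pydantic object.
model_response.created = int(time.time())
model_response.model = "my-model"

# Usage is attached with setattr, wrapped in litellm.Usage instead of a raw dict.
setattr(
    model_response,
    "usage",
    litellm.Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)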

View file

@ -1,12 +1,12 @@
repos: repos:
- repo: local - repo: local
hooks: hooks:
- id: mypy # - id: mypy
name: mypy # name: mypy
entry: python3 -m mypy --ignore-missing-imports # entry: python3 -m mypy --ignore-missing-imports
language: system # language: system
types: [python] # types: [python]
files: ^litellm/ # files: ^litellm/
- id: isort - id: isort
name: isort name: isort
entry: isort entry: isort

View file

@ -1,11 +1,16 @@
import os, types, traceback
import json import json
import os
import time # type: ignore
import traceback
import types
from enum import Enum from enum import Enum
import requests # type: ignore
import time, httpx # type: ignore
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message
import httpx
import requests # type: ignore
import litellm import litellm
from litellm.utils import Choices, Message, ModelResponse
class AI21Error(Exception): class AI21Error(Exception):
@ -185,7 +190,7 @@ def completion(
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list # type: ignore
except Exception as e: except Exception as e:
raise AI21Error( raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
@ -197,13 +202,17 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
model_response["usage"] = { setattr(
"prompt_tokens": prompt_tokens, model_response,
"completion_tokens": completion_tokens, "usage",
"total_tokens": prompt_tokens + completion_tokens, litellm.Usage(
} prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response return model_response
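
For multi-choice providers such as AI21, the handlers keep building a list of Choices objects and now assign it to the choices attribute directly; the trailing # type: ignore silences the mismatch between list[Choices] and the attribute's declared type. A short sketch with illustrative candidate texts, using the same litellm.utils imports as the file above:

from litellm.utils import Choices, Message, ModelResponse

model_response = ModelResponse()
choices_list = []
for idx, text in enumerate(["first candidate", "second candidate"]):
    message_obj = Message(content=text)
    choice_obj = Choices(finish_reason="stop", index=idx, message=message_obj)
    choices_list.append(choice_obj)
model_response.choices = choices_list  # type: ignore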

View file

@ -1,12 +1,15 @@
import os, types
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
import types
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import Choices, Message, ModelResponse, Usage
class AlephAlphaError(Exception): class AlephAlphaError(Exception):
@ -275,7 +278,7 @@ def completion(
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list # type: ignore
except: except:
raise AlephAlphaError( raise AlephAlphaError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
@ -291,8 +294,8 @@ def completion(
) )
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -607,8 +607,8 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,15 +1,19 @@
import os, types
import json import json
from enum import Enum import os
import requests
import time import time
import types
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
from .base import BaseLLM import requests
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
@ -117,9 +121,9 @@ class AnthropicTextCompletion(BaseLLM):
) )
else: else:
if len(completion_response["completion"]) > 0: if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[ # type: ignore
completion_response["completion"] "completion"
) ]
model_response.choices[0].finish_reason = completion_response["stop_reason"] model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE ## CALCULATING USAGE
@ -130,8 +134,8 @@ class AnthropicTextCompletion(BaseLLM):
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here ) ##[TODO] use the anthropic tokenizer here
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,9 +1,11 @@
import os
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
from enum import Enum
from typing import Callable from typing import Callable
import requests # type: ignore
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
@ -106,28 +108,32 @@ def completion(
and "data" in completion_response["model_output"] and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list) and isinstance(completion_response["model_output"]["data"], list)
): ):
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[ # type: ignore
completion_response["model_output"]["data"][0] "model_output"
) ][
"data"
][
0
]
elif isinstance(completion_response["model_output"], str): elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[ # type: ignore
completion_response["model_output"] "model_output"
) ]
elif "completion" in completion_response and isinstance( elif "completion" in completion_response and isinstance(
completion_response["completion"], str completion_response["completion"], str
): ):
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[ # type: ignore
completion_response["completion"] "completion"
) ]
elif isinstance(completion_response, list) and len(completion_response) > 0: elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response: if "generated_text" not in completion_response:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code, status_code=response.status_code,
) )
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[0][ # type: ignore
completion_response[0]["generated_text"] "generated_text"
) ]
## GETTING LOGPROBS ## GETTING LOGPROBS
if ( if (
"details" in completion_response[0] "details" in completion_response[0]
@ -139,7 +145,7 @@ def completion(
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprobs = sum_logprob model_response.choices[0].logprobs = sum_logprob
else: else:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
@ -152,8 +158,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"]["content"]) encoding.encode(model_response["choices"][0]["message"]["content"])
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
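
The nested dictionary writes such as model_response["choices"][0]["message"]["content"] become attribute paths on the first choice, as in the Baseten handler above. A sketch of the write pattern, assuming ModelResponse() ships with one default Choices/Message pair; the strings are illustrative:

import litellm

model_response = litellm.ModelResponse()
model_response.choices[0].message.content = "generated text from the provider"  # type: ignore
model_response.choices[0].finish_reason = "stop"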

View file

@ -1122,7 +1122,7 @@ def completion(
logging_obj=logging_obj, logging_obj=logging_obj,
) )
model_response["finish_reason"] = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
response_body["stop_reason"] response_body["stop_reason"]
) )
_usage = litellm.Usage( _usage = litellm.Usage(
@ -1134,14 +1134,16 @@ def completion(
setattr(model_response, "usage", _usage) setattr(model_response, "usage", _usage)
else: else:
outputText = response_body["completion"] outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"] model_response.choices[0].finish_reason = response_body["stop_reason"]
elif provider == "cohere": elif provider == "cohere":
outputText = response_body["generations"][0]["text"] outputText = response_body["generations"][0]["text"]
elif provider == "meta": elif provider == "meta":
outputText = response_body["generation"] outputText = response_body["generation"]
elif provider == "mistral": elif provider == "mistral":
outputText = response_body["outputs"][0]["text"] outputText = response_body["outputs"][0]["text"]
model_response["finish_reason"] = response_body["outputs"][0]["stop_reason"] model_response.choices[0].finish_reason = response_body["outputs"][0][
"stop_reason"
]
else: # amazon titan else: # amazon titan
outputText = response_body.get("results")[0].get("outputText") outputText = response_body.get("results")[0].get("outputText")
@ -1160,7 +1162,7 @@ def completion(
and getattr(model_response.choices[0].message, "tool_calls", None) and getattr(model_response.choices[0].message, "tool_calls", None)
is None is None
): ):
model_response["choices"][0]["message"]["content"] = outputText model_response.choices[0].message.content = outputText
elif ( elif (
hasattr(model_response.choices[0], "message") hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None) and getattr(model_response.choices[0].message, "tool_calls", None)
@ -1199,8 +1201,8 @@ def completion(
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
model_response._hidden_params["region_name"] = client.meta.region_name model_response._hidden_params["region_name"] = client.meta.region_name
print_verbose(f"model_response._hidden_params: {model_response._hidden_params}") print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
@ -1323,9 +1325,9 @@ def _embedding_func_single(
def embedding( def embedding(
model: str, model: str,
input: Union[list, str], input: Union[list, str],
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
): ):
@ -1391,9 +1393,9 @@ def embedding(
"embedding": embedding, "embedding": embedding,
} }
) )
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = embedding_response model_response.data = embedding_response
model_response["model"] = model model_response.model = model
input_tokens = 0 input_tokens = 0
input_str = "".join(input) input_str = "".join(input)

View file

@ -521,7 +521,7 @@ class BedrockLLM(BaseLLM):
outputText = completion_response["text"] # type: ignore outputText = completion_response["text"] # type: ignore
elif "generations" in completion_response: elif "generations" in completion_response:
outputText = completion_response["generations"][0]["text"] outputText = completion_response["generations"][0]["text"]
model_response["finish_reason"] = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
completion_response["generations"][0]["finish_reason"] completion_response["generations"][0]["finish_reason"]
) )
elif provider == "anthropic": elif provider == "anthropic":
@ -625,7 +625,7 @@ class BedrockLLM(BaseLLM):
logging_obj=logging_obj, logging_obj=logging_obj,
) )
model_response["finish_reason"] = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
completion_response.get("stop_reason", "") completion_response.get("stop_reason", "")
) )
_usage = litellm.Usage( _usage = litellm.Usage(
@ -638,7 +638,9 @@ class BedrockLLM(BaseLLM):
else: else:
outputText = completion_response["completion"] outputText = completion_response["completion"]
model_response["finish_reason"] = completion_response["stop_reason"] model_response.choices[0].finish_reason = completion_response[
"stop_reason"
]
elif provider == "ai21": elif provider == "ai21":
outputText = ( outputText = (
completion_response.get("completions")[0].get("data").get("text") completion_response.get("completions")[0].get("data").get("text")
@ -647,9 +649,9 @@ class BedrockLLM(BaseLLM):
outputText = completion_response["generation"] outputText = completion_response["generation"]
elif provider == "mistral": elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"] outputText = completion_response["outputs"][0]["text"]
model_response["finish_reason"] = completion_response["outputs"][0][ model_response.choices[0].finish_reason = completion_response[
"stop_reason" "outputs"
] ][0]["stop_reason"]
else: # amazon titan else: # amazon titan
outputText = completion_response.get("results")[0].get("outputText") outputText = completion_response.get("results")[0].get("outputText")
except Exception as e: except Exception as e:
@ -667,7 +669,7 @@ class BedrockLLM(BaseLLM):
and getattr(model_response.choices[0].message, "tool_calls", None) and getattr(model_response.choices[0].message, "tool_calls", None)
is None is None
): ):
model_response["choices"][0]["message"]["content"] = outputText model_response.choices[0].message.content = outputText
elif ( elif (
hasattr(model_response.choices[0], "message") hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None) and getattr(model_response.choices[0].message, "tool_calls", None)
@ -723,8 +725,8 @@ class BedrockLLM(BaseLLM):
) )
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -1446,8 +1448,8 @@ class BedrockConverseLLM(BaseLLM):
message=litellm.Message(**chat_completion_message), message=litellm.Message(**chat_completion_message),
) )
] ]
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, prompt_tokens=input_tokens,
completion_tokens=output_tokens, completion_tokens=output_tokens,
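
Note that the former top-level model_response["finish_reason"] writes move onto the first choice, typically after normalizing the provider's stop reason with map_finish_reason (imported elsewhere in this commit from litellm.litellm_core_utils.core_helpers). A sketch with an illustrative provider stop reason:

import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason

model_response = litellm.ModelResponse()
provider_stop_reason = "end_turn"  # illustrative value from a provider payload
model_response.choices[0].finish_reason = map_finish_reason(provider_stop_reason)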

View file

@ -1,13 +1,18 @@
import os, types, traceback
import json import json
import requests import os
import time import time
import traceback
import types
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, Choices, Message, CustomStreamWrapper
import litellm
import httpx import httpx
import requests
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .prompt_templates.factory import prompt_factory, custom_prompt from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
class ClarifaiError(Exception): class ClarifaiError(Exception):
@ -87,7 +92,14 @@ def completions_to_model(payload):
def process_response( def process_response(
- model, prompt, response, model_response, api_key, data, encoding, logging_obj
+ model,
+ prompt,
+ response,
+ model_response: litellm.ModelResponse,
+ api_key,
+ data,
+ encoding,
+ logging_obj,
): ):
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -116,7 +128,7 @@ def process_response(
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list
except Exception as e: except Exception as e:
raise ClarifaiError( raise ClarifaiError(
@ -128,11 +140,15 @@ def process_response(
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
model_response["model"] = model model_response.model = model
model_response["usage"] = Usage( setattr(
model_response,
"usage",
Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
@ -202,7 +218,7 @@ async def async_completion(
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list # type: ignore
except Exception as e: except Exception as e:
raise ClarifaiError( raise ClarifaiError(
@ -214,11 +230,15 @@ async def async_completion(
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
model_response["model"] = model model_response.model = model
model_response["usage"] = Usage( setattr(
model_response,
"usage",
Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
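
Handler signatures also gain an explicit litellm.ModelResponse annotation, which is what makes the attribute writes above visible to mypy. A stripped-down sketch; the function and argument names are illustrative rather than the real Clarifai handler:

import time

import litellm


def fill_response(model_response: litellm.ModelResponse, text: str) -> litellm.ModelResponse:
    # With the parameter annotated, these attribute accesses are type-checked.
    model_response.choices[0].message.content = text  # type: ignore
    model_response.created = int(time.time())
    return model_response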

View file

@ -1,13 +1,17 @@
import os, types
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
import types
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import litellm
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import custom_prompt, prompt_factory
class CloudflareError(Exception): class CloudflareError(Exception):
@ -147,9 +151,9 @@ def completion(
) )
completion_response = response.json() completion_response = response.json()
model_response["choices"][0]["message"]["content"] = completion_response[ model_response.choices[0].message.content = completion_response["result"][ # type: ignore
"result" "response"
]["response"] ]
## CALCULATING USAGE ## CALCULATING USAGE
print_verbose( print_verbose(
@ -160,8 +164,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = "cloudflare/" + model model_response.model = "cloudflare/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,12 +1,16 @@
import os, types
import json import json
import os
import time
import traceback
import types
from enum import Enum from enum import Enum
import requests # type: ignore
import time, traceback
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import Choices, Message, ModelResponse, Usage
class CohereError(Exception): class CohereError(Exception):
@ -117,7 +121,7 @@ class CohereConfig:
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"Request-Source":"unspecified:litellm", "Request-Source": "unspecified:litellm",
"accept": "application/json", "accept": "application/json",
"content-type": "application/json", "content-type": "application/json",
} }
@ -219,7 +223,7 @@ def completion(
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list # type: ignore
except Exception as e: except Exception as e:
raise CohereError( raise CohereError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
@ -231,8 +235,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -245,9 +249,9 @@ def completion(
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
encoding=None, encoding=None,
optional_params=None, optional_params=None,
): ):
@ -294,14 +298,18 @@ def embedding(
output_data.append( output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding} {"object": "embedding", "index": idx, "embedding": embedding}
) )
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = output_data model_response.data = output_data
model_response["model"] = model model_response.model = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens += len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage( setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
),
) )
return model_response return model_response
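
Token accounting for embeddings keeps the same counting loop; only the final write changes from a dict to a Usage object attached with setattr. A self-contained sketch that uses a tiktoken encoder as a stand-in for the encoding object the handler receives:

import litellm
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # stand-in for the handler's encoding argument
embedding_response = litellm.EmbeddingResponse()

input_texts = ["hello world", "second document"]
input_tokens = 0
for text in input_texts:
    input_tokens += len(encoding.encode(text))

setattr(
    embedding_response,
    "usage",
    litellm.Usage(
        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
    ),
)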

View file

@ -305,8 +305,8 @@ def completion(
prompt_tokens = billed_units.get("input_tokens", 0) prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0) completion_tokens = billed_units.get("output_tokens", 0)
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,26 +1,26 @@
# What is this? # What is this?
## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
from functools import partial import copy
import os, types
import json import json
from enum import Enum import os
import requests, copy # type: ignore
import time import time
from typing import Callable, Optional, List, Union, Tuple, Literal import types
from litellm.utils import ( from enum import Enum
ModelResponse, from functools import partial
Usage, from typing import Callable, List, Literal, Optional, Tuple, Union
CustomStreamWrapper,
EmbeddingResponse,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import ProviderField from litellm.types.utils import ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class DatabricksError(Exception): class DatabricksError(Exception):
@ -354,8 +354,8 @@ class DatabricksChatCompletion(BaseLLM):
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,7 +1,7 @@
#################################### # ####################################
######### DEPRECATED FILE ########## # ######### DEPRECATED FILE ##########
#################################### # ####################################
# logic moved to `vertex_httpx.py` # # # logic moved to `vertex_httpx.py` #
import copy import copy
import time import time
@ -92,332 +92,332 @@ class GeminiConfig:
} }
class TextStreamer: # class TextStreamer:
""" # """
A class designed to return an async stream from AsyncGenerateContentResponse object. # A class designed to return an async stream from AsyncGenerateContentResponse object.
""" # """
def __init__(self, response): # def __init__(self, response):
self.response = response # self.response = response
self._aiter = self.response.__aiter__() # self._aiter = self.response.__aiter__()
async def __aiter__(self): # async def __aiter__(self):
while True: # while True:
try: # try:
# This will manually advance the async iterator. # # This will manually advance the async iterator.
# In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception # # In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception
next_object = await self._aiter.__anext__() # next_object = await self._aiter.__anext__()
yield next_object # yield next_object
except StopAsyncIteration: # except StopAsyncIteration:
# After getting all items from the async iterator, stop iterating # # After getting all items from the async iterator, stop iterating
break # break
def supports_system_instruction(): # def supports_system_instruction():
import google.generativeai as genai # import google.generativeai as genai
gemini_pkg_version = Version(genai.__version__) # gemini_pkg_version = Version(genai.__version__)
return gemini_pkg_version >= Version("0.5.0") # return gemini_pkg_version >= Version("0.5.0")
def completion( # def completion(
model: str, # model: str,
messages: list, # messages: list,
model_response: ModelResponse, # model_response: ModelResponse,
print_verbose: Callable, # print_verbose: Callable,
api_key, # api_key,
encoding, # encoding,
logging_obj, # logging_obj,
custom_prompt_dict: dict, # custom_prompt_dict: dict,
acompletion: bool = False, # acompletion: bool = False,
optional_params=None, # optional_params=None,
litellm_params=None, # litellm_params=None,
logger_fn=None, # logger_fn=None,
): # ):
try: # try:
import google.generativeai as genai # type: ignore # import google.generativeai as genai # type: ignore
except: # except:
raise Exception( # raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai" # "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
) # )
genai.configure(api_key=api_key) # genai.configure(api_key=api_key)
system_prompt = "" # system_prompt = ""
if model in custom_prompt_dict: # if model in custom_prompt_dict:
# check if the model has a registered custom prompt # # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] # model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( # prompt = custom_prompt(
role_dict=model_prompt_details["roles"], # role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], # initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], # final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages, # messages=messages,
) # )
else: # else:
system_prompt, messages = get_system_prompt(messages=messages) # system_prompt, messages = get_system_prompt(messages=messages)
prompt = prompt_factory( # prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini" # model=model, messages=messages, custom_llm_provider="gemini"
) # )
## Load Config # ## Load Config
inference_params = copy.deepcopy(optional_params) # inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", None) # stream = inference_params.pop("stream", None)
# Handle safety settings # # Handle safety settings
safety_settings_param = inference_params.pop("safety_settings", None) # safety_settings_param = inference_params.pop("safety_settings", None)
safety_settings = None # safety_settings = None
if safety_settings_param: # if safety_settings_param:
safety_settings = [ # safety_settings = [
genai.types.SafetySettingDict(x) for x in safety_settings_param # genai.types.SafetySettingDict(x) for x in safety_settings_param
] # ]
config = litellm.GeminiConfig.get_config() # config = litellm.GeminiConfig.get_config()
for k, v in config.items(): # for k, v in config.items():
if ( # if (
k not in inference_params # k not in inference_params
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in # ): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v # inference_params[k] = v
## LOGGING # ## LOGGING
logging_obj.pre_call( # logging_obj.pre_call(
input=prompt, # input=prompt,
api_key="", # api_key="",
additional_args={ # additional_args={
"complete_input_dict": { # "complete_input_dict": {
"inference_params": inference_params, # "inference_params": inference_params,
"system_prompt": system_prompt, # "system_prompt": system_prompt,
} # }
}, # },
) # )
## COMPLETION CALL # ## COMPLETION CALL
try: # try:
_params = {"model_name": "models/{}".format(model)} # _params = {"model_name": "models/{}".format(model)}
_system_instruction = supports_system_instruction() # _system_instruction = supports_system_instruction()
if _system_instruction and len(system_prompt) > 0: # if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt # _params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params) # _model = genai.GenerativeModel(**_params)
if stream is True: # if stream is True:
if acompletion is True: # if acompletion is True:
async def async_streaming(): # async def async_streaming():
try: # try:
response = await _model.generate_content_async( # response = await _model.generate_content_async(
contents=prompt, # contents=prompt,
generation_config=genai.types.GenerationConfig( # generation_config=genai.types.GenerationConfig(
**inference_params # **inference_params
), # ),
safety_settings=safety_settings, # safety_settings=safety_settings,
stream=True, # stream=True,
) # )
response = litellm.CustomStreamWrapper( # response = litellm.CustomStreamWrapper(
TextStreamer(response), # TextStreamer(response),
model, # model,
custom_llm_provider="gemini", # custom_llm_provider="gemini",
logging_obj=logging_obj, # logging_obj=logging_obj,
) # )
return response # return response
except Exception as e: # except Exception as e:
raise GeminiError(status_code=500, message=str(e)) # raise GeminiError(status_code=500, message=str(e))
return async_streaming() # return async_streaming()
response = _model.generate_content( # response = _model.generate_content(
contents=prompt, # contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params), # generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings, # safety_settings=safety_settings,
stream=True, # stream=True,
) # )
return response # return response
elif acompletion == True: # elif acompletion == True:
return async_completion( # return async_completion(
_model=_model, # _model=_model,
model=model, # model=model,
prompt=prompt, # prompt=prompt,
inference_params=inference_params, # inference_params=inference_params,
safety_settings=safety_settings, # safety_settings=safety_settings,
logging_obj=logging_obj, # logging_obj=logging_obj,
print_verbose=print_verbose, # print_verbose=print_verbose,
model_response=model_response, # model_response=model_response,
messages=messages, # messages=messages,
encoding=encoding, # encoding=encoding,
) # )
else: # else:
params = { # params = {
"contents": prompt, # "contents": prompt,
"generation_config": genai.types.GenerationConfig(**inference_params), # "generation_config": genai.types.GenerationConfig(**inference_params),
"safety_settings": safety_settings, # "safety_settings": safety_settings,
} # }
response = _model.generate_content(**params) # response = _model.generate_content(**params)
except Exception as e: # except Exception as e:
raise GeminiError( # raise GeminiError(
message=str(e), # message=str(e),
status_code=500, # status_code=500,
) # )
## LOGGING # ## LOGGING
logging_obj.post_call( # logging_obj.post_call(
input=prompt, # input=prompt,
api_key="", # api_key="",
original_response=response, # original_response=response,
additional_args={"complete_input_dict": {}}, # additional_args={"complete_input_dict": {}},
) # )
print_verbose(f"raw model_response: {response}") # print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT # ## RESPONSE OBJECT
completion_response = response # completion_response = response
try: # try:
choices_list = [] # choices_list = []
for idx, item in enumerate(completion_response.candidates): # for idx, item in enumerate(completion_response.candidates):
if len(item.content.parts) > 0: # if len(item.content.parts) > 0:
message_obj = Message(content=item.content.parts[0].text) # message_obj = Message(content=item.content.parts[0].text)
else: # else:
message_obj = Message(content=None) # message_obj = Message(content=None)
choice_obj = Choices(index=idx, message=message_obj) # choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj) # choices_list.append(choice_obj)
model_response["choices"] = choices_list # model_response.choices = choices_list
except Exception as e: # except Exception as e:
verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e))) # verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc()) # verbose_logger.debug(traceback.format_exc())
raise GeminiError( # raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code # message=traceback.format_exc(), status_code=response.status_code
) # )
try: # try:
completion_response = model_response["choices"][0]["message"].get("content") # completion_response = model_response["choices"][0]["message"].get("content")
if completion_response is None: # if completion_response is None:
raise Exception # raise Exception
except: # except:
original_response = f"response: {response}" # original_response = f"response: {response}"
if hasattr(response, "candidates"): # if hasattr(response, "candidates"):
original_response = f"response: {response.candidates}" # original_response = f"response: {response.candidates}"
if "SAFETY" in original_response: # if "SAFETY" in original_response:
original_response += ( # original_response += (
"\nThe candidate content was flagged for safety reasons." # "\nThe candidate content was flagged for safety reasons."
) # )
elif "RECITATION" in original_response: # elif "RECITATION" in original_response:
original_response += ( # original_response += (
"\nThe candidate content was flagged for recitation reasons." # "\nThe candidate content was flagged for recitation reasons."
) # )
raise GeminiError( # raise GeminiError(
status_code=400, # status_code=400,
message=f"No response received. Original response - {original_response}", # message=f"No response received. Original response - {original_response}",
) # )
## CALCULATING USAGE # ## CALCULATING USAGE
prompt_str = "" # prompt_str = ""
for m in messages: # for m in messages:
if isinstance(m["content"], str): # if isinstance(m["content"], str):
prompt_str += m["content"] # prompt_str += m["content"]
elif isinstance(m["content"], list): # elif isinstance(m["content"], list):
for content in m["content"]: # for content in m["content"]:
if content["type"] == "text": # if content["type"] == "text":
prompt_str += content["text"] # prompt_str += content["text"]
prompt_tokens = len(encoding.encode(prompt_str)) # prompt_tokens = len(encoding.encode(prompt_str))
completion_tokens = len( # completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) # encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) # )
model_response["created"] = int(time.time()) # model_response.created = int(time.time())
model_response["model"] = "gemini/" + model # model_response.model = "gemini/" + model
usage = Usage( # usage = Usage(
prompt_tokens=prompt_tokens, # prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, # completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, # total_tokens=prompt_tokens + completion_tokens,
) # )
setattr(model_response, "usage", usage) # setattr(model_response, "usage", usage)
return model_response # return model_response
async def async_completion( # async def async_completion(
_model, # _model,
model, # model,
prompt, # prompt,
inference_params, # inference_params,
safety_settings, # safety_settings,
logging_obj, # logging_obj,
print_verbose, # print_verbose,
model_response, # model_response,
messages, # messages,
encoding, # encoding,
): # ):
import google.generativeai as genai # type: ignore # import google.generativeai as genai # type: ignore
response = await _model.generate_content_async( # response = await _model.generate_content_async(
contents=prompt, # contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params), # generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings, # safety_settings=safety_settings,
) # )
## LOGGING # ## LOGGING
logging_obj.post_call( # logging_obj.post_call(
input=prompt, # input=prompt,
api_key="", # api_key="",
original_response=response, # original_response=response,
additional_args={"complete_input_dict": {}}, # additional_args={"complete_input_dict": {}},
) # )
print_verbose(f"raw model_response: {response}") # print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT # ## RESPONSE OBJECT
completion_response = response # completion_response = response
try: # try:
choices_list = [] # choices_list = []
for idx, item in enumerate(completion_response.candidates): # for idx, item in enumerate(completion_response.candidates):
if len(item.content.parts) > 0: # if len(item.content.parts) > 0:
message_obj = Message(content=item.content.parts[0].text) # message_obj = Message(content=item.content.parts[0].text)
else: # else:
message_obj = Message(content=None) # message_obj = Message(content=None)
choice_obj = Choices(index=idx, message=message_obj) # choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj) # choices_list.append(choice_obj)
model_response["choices"] = choices_list # model_response["choices"] = choices_list
except Exception as e: # except Exception as e:
verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e))) # verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
verbose_logger.debug(traceback.format_exc()) # verbose_logger.debug(traceback.format_exc())
raise GeminiError( # raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code # message=traceback.format_exc(), status_code=response.status_code
) # )
try: # try:
completion_response = model_response["choices"][0]["message"].get("content") # completion_response = model_response["choices"][0]["message"].get("content")
if completion_response is None: # if completion_response is None:
raise Exception # raise Exception
except: # except:
original_response = f"response: {response}" # original_response = f"response: {response}"
if hasattr(response, "candidates"): # if hasattr(response, "candidates"):
original_response = f"response: {response.candidates}" # original_response = f"response: {response.candidates}"
if "SAFETY" in original_response: # if "SAFETY" in original_response:
original_response += ( # original_response += (
"\nThe candidate content was flagged for safety reasons." # "\nThe candidate content was flagged for safety reasons."
) # )
elif "RECITATION" in original_response: # elif "RECITATION" in original_response:
original_response += ( # original_response += (
"\nThe candidate content was flagged for recitation reasons." # "\nThe candidate content was flagged for recitation reasons."
) # )
raise GeminiError( # raise GeminiError(
status_code=400, # status_code=400,
message=f"No response received. Original response - {original_response}", # message=f"No response received. Original response - {original_response}",
) # )
## CALCULATING USAGE # ## CALCULATING USAGE
prompt_str = "" # prompt_str = ""
for m in messages: # for m in messages:
if isinstance(m["content"], str): # if isinstance(m["content"], str):
prompt_str += m["content"] # prompt_str += m["content"]
elif isinstance(m["content"], list): # elif isinstance(m["content"], list):
for content in m["content"]: # for content in m["content"]:
if content["type"] == "text": # if content["type"] == "text":
prompt_str += content["text"] # prompt_str += content["text"]
prompt_tokens = len(encoding.encode(prompt_str)) # prompt_tokens = len(encoding.encode(prompt_str))
completion_tokens = len( # completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) # encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) # )
model_response["created"] = int(time.time()) # model_response["created"] = int(time.time())
model_response["model"] = "gemini/" + model # model_response["model"] = "gemini/" + model
usage = Usage( # usage = Usage(
prompt_tokens=prompt_tokens, # prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, # completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, # total_tokens=prompt_tokens + completion_tokens,
) # )
model_response.usage = usage # model_response.usage = usage
return model_response # return model_response
def embedding(): # def embedding():
# logic for parsing in - calling - parsing out model embedding calls # # logic for parsing in - calling - parsing out model embedding calls
pass # pass

View file

@ -1,17 +1,22 @@
## Uses the huggingface text generation inference API ## Uses the huggingface text generation inference API
import os, copy, types import copy
import json
from enum import Enum
import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any, Literal, Tuple
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum import enum
import json
import os
import time
import types
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import httpx
import requests
import litellm
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class HuggingfaceError(Exception): class HuggingfaceError(Exception):
@ -269,7 +274,7 @@ class Huggingface(BaseLLM):
def convert_to_model_response_object( def convert_to_model_response_object(
self, self,
completion_response, completion_response,
model_response, model_response: litellm.ModelResponse,
task: hf_tasks, task: hf_tasks,
optional_params, optional_params,
encoding, encoding,
@ -278,11 +283,9 @@ class Huggingface(BaseLLM):
): ):
if task == "conversational": if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][ model_response.choices[0].message.content = completion_response[ # type: ignore
"content"
] = completion_response[
"generated_text" "generated_text"
] # type: ignore ]
elif task == "text-generation-inference": elif task == "text-generation-inference":
if ( if (
not isinstance(completion_response, list) not isinstance(completion_response, list)
@ -295,7 +298,7 @@ class Huggingface(BaseLLM):
) )
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser( model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"] completion_response[0]["generated_text"]
) )
## GETTING LOGPROBS + FINISH REASON ## GETTING LOGPROBS + FINISH REASON
@ -310,7 +313,7 @@ class Huggingface(BaseLLM):
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None: if token["logprob"] != None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
if ( if (
"details" in completion_response[0] "details" in completion_response[0]
@ -337,14 +340,14 @@ class Huggingface(BaseLLM):
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response.choices.extend(choices_list)
elif task == "text-classification": elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps( model_response.choices[0].message.content = json.dumps( # type: ignore
completion_response completion_response
) )
else: else:
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser( model_response.choices[0].message.content = output_parser( # type: ignore
completion_response[0]["generated_text"] completion_response[0]["generated_text"]
) )
## CALCULATING USAGE ## CALCULATING USAGE
@ -371,14 +374,14 @@ class Huggingface(BaseLLM):
else: else:
completion_tokens = 0 completion_tokens = 0
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
model_response._hidden_params["original_response"] = completion_response model_response._hidden_params["original_response"] = completion_response
return model_response return model_response
@ -763,10 +766,10 @@ class Huggingface(BaseLLM):
self, self,
model: str, model: str,
input: list, input: list,
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
encoding=None, encoding=None,
): ):
super().embedding() super().embedding()
@ -867,15 +870,21 @@ class Huggingface(BaseLLM):
], # flatten list returned from hf ], # flatten list returned from hf
} }
) )
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = output_data model_response.data = output_data
model_response["model"] = model model_response.model = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens += len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = { setattr(
model_response,
"usage",
litellm.Usage(
**{
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
),
)
return model_response return model_response
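
Where a handler previously stored a plain usage dict, the dict is now unpacked into litellm.Usage before being attached, as in the Hugging Face embedding path above. A small sketch with illustrative counts:

import litellm

model_response = litellm.EmbeddingResponse()
usage_dict = {"prompt_tokens": 12, "total_tokens": 12}
setattr(model_response, "usage", litellm.Usage(**usage_dict))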

View file

@ -1,11 +1,15 @@
import os, types
import json import json
import os
import time
import traceback
import types
from enum import Enum from enum import Enum
from typing import Callable, List, Optional
import requests # type: ignore import requests # type: ignore
import time, traceback
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
from litellm.utils import Choices, Message, ModelResponse, Usage
class MaritalkError(Exception): class MaritalkError(Exception):
@ -152,9 +156,9 @@ def completion(
else: else:
try: try:
if len(completion_response["answer"]) > 0: if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response[ # type: ignore
completion_response["answer"] "answer"
) ]
except Exception as e: except Exception as e:
raise MaritalkError( raise MaritalkError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
@ -167,8 +171,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,9 +1,12 @@
import os, types
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
import types
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
@ -185,7 +188,7 @@ def completion(
else: else:
try: try:
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = ( # type: ignore
completion_response["generated_text"] completion_response["generated_text"]
) )
except: except:
@ -198,8 +201,8 @@ def completion(
prompt_tokens = completion_response["nb_input_tokens"] prompt_tokens = completion_response["nb_input_tokens"]
completion_tokens = completion_response["nb_generated_tokens"] completion_tokens = completion_response["nb_generated_tokens"]
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,13 +1,21 @@
from itertools import chain import asyncio
import requests, types, time # type: ignore import json
import json, uuid import time
import traceback import traceback
from typing import Optional, List import types
import uuid
from itertools import chain
from typing import List, Optional
import aiohttp
import httpx # type: ignore
import requests # type: ignore
import litellm import litellm
from litellm.types.utils import ProviderField
import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.utils import ProviderField
from .prompt_templates.factory import custom_prompt, prompt_factory
class OllamaError(Exception): class OllamaError(Exception):
@ -138,7 +146,6 @@ class OllamaConfig:
) )
] ]
def get_supported_openai_params( def get_supported_openai_params(
self, self,
): ):
@ -157,7 +164,8 @@ class OllamaConfig:
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary. # and convert to jpeg if necessary.
def _convert_image(image): def _convert_image(image):
import base64, io import base64
import io
try: try:
from PIL import Image from PIL import Image
@ -183,13 +191,13 @@ def _convert_image(image):
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
model_response: litellm.ModelResponse,
api_base="http://localhost:11434", api_base="http://localhost:11434",
model="llama2", model="llama2",
prompt="Why is the sky blue?", prompt="Why is the sky blue?",
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
model_response=None,
encoding=None, encoding=None,
): ):
if api_base.endswith("/api/generate"): if api_base.endswith("/api/generate"):
@ -271,7 +279,7 @@ def get_ollama_response(
response_json = response.json() response_json = response.json()
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["response"]) function_call = json.loads(response_json["response"])
message = litellm.Message( message = litellm.Message(
@ -287,20 +295,24 @@ def get_ollama_response(
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response.choices[0].message = message # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response.choices[0].message.content = response_json["response"] # type: ignore
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = "ollama/" + model model_response.model = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get( completion_tokens = response_json.get(
"eval_count", len(response_json.get("message", dict()).get("content", "")) "eval_count", len(response_json.get("message", dict()).get("content", ""))
) )
model_response["usage"] = litellm.Usage( setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
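Taken together, the edits in this hunk are the core pattern of the commit: fields on the pydantic ModelResponse are written as attributes, and usage is attached with setattr rather than dict-style item assignment. A condensed sketch of that pattern, assuming only the litellm.ModelResponse and litellm.Usage objects already used in this file:

import time
import litellm

def fill_model_response(
    model_response: litellm.ModelResponse,
    text: str,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> litellm.ModelResponse:
    # Attribute access replaces model_response["choices"][0]["message"]["content"] = ...
    model_response.choices[0].message.content = text  # type: ignore
    model_response.choices[0].finish_reason = "stop"
    model_response.created = int(time.time())
    model_response.model = "ollama/" + model
    # usage is attached via setattr, mirroring the hunk above
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return model_response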
@ -346,8 +358,8 @@ def ollama_completion_stream(url, data, logging_obj):
], ],
) )
model_response = first_chunk model_response = first_chunk
model_response["choices"][0]["delta"] = delta model_response.choices[0].delta = delta # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
yield model_response yield model_response
else: else:
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
@ -401,8 +413,8 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
], ],
) )
model_response = first_chunk model_response = first_chunk
model_response["choices"][0]["delta"] = delta model_response.choices[0].delta = delta # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
yield model_response yield model_response
else: else:
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
@ -418,7 +430,9 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
raise e raise e
async def ollama_acompletion(url, data, model_response, encoding, logging_obj): async def ollama_acompletion(
url, data, model_response: litellm.ModelResponse, encoding, logging_obj
):
data["stream"] = False data["stream"] = False
try: try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
@ -442,7 +456,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
response_json = await resp.json() response_json = await resp.json()
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["response"]) function_call = json.loads(response_json["response"])
message = litellm.Message( message = litellm.Message(
@ -451,30 +465,34 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {
"name": function_call.get("name", function_call.get("function", None)), "name": function_call.get(
"name", function_call.get("function", None)
),
"arguments": json.dumps(function_call["arguments"]), "arguments": json.dumps(function_call["arguments"]),
}, },
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response.choices[0].message = message # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json[ model_response.choices[0].message.content = response_json["response"] # type: ignore
"response" model_response.created = int(time.time())
] model_response.model = "ollama/" + data["model"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get( completion_tokens = response_json.get(
"eval_count", "eval_count",
len(response_json.get("message", dict()).get("content", "")), len(response_json.get("message", dict()).get("content", "")),
) )
model_response["usage"] = litellm.Usage( setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
except Exception as e: except Exception as e:
@ -491,9 +509,9 @@ async def ollama_aembeddings(
api_base: str, api_base: str,
model: str, model: str,
prompts: list, prompts: list,
model_response: litellm.EmbeddingResponse,
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
model_response=None,
encoding=None, encoding=None,
): ):
if api_base.endswith("/api/embeddings"): if api_base.endswith("/api/embeddings"):
@ -554,13 +572,19 @@ async def ollama_aembeddings(
input_tokens = len(encoding.encode(prompt)) input_tokens = len(encoding.encode(prompt))
total_input_tokens += input_tokens total_input_tokens += input_tokens
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = output_data model_response.data = output_data
model_response["model"] = model model_response.model = model
model_response["usage"] = { setattr(
model_response,
"usage",
litellm.Usage(
**{
"prompt_tokens": total_input_tokens, "prompt_tokens": total_input_tokens,
"total_tokens": total_input_tokens, "total_tokens": total_input_tokens,
} }
),
)
return model_response return model_response
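The embedding path follows the same convention on EmbeddingResponse: object, data, and model are set as attributes, and usage is attached with setattr. A small illustrative sketch (the embedding vector, token count, and model name below are placeholders):

import litellm

embedding_response = litellm.EmbeddingResponse()
output_data = [{"object": "embedding", "index": 0, "embedding": [0.0, 0.1, 0.2]}]  # placeholder vector
total_input_tokens = 3  # placeholder count

embedding_response.object = "list"
embedding_response.data = output_data
embedding_response.model = "ollama/nomic-embed-text"  # placeholder model name
setattr(
    embedding_response,
    "usage",
    litellm.Usage(prompt_tokens=total_input_tokens, total_tokens=total_input_tokens),
)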

View file

@ -1,15 +1,17 @@
from itertools import chain
import requests
import types
import time
import json import json
import uuid import time
import traceback import traceback
import types
import uuid
from itertools import chain
from typing import Optional from typing import Optional
from litellm import verbose_logger
import litellm
import httpx
import aiohttp import aiohttp
import httpx
import requests
import litellm
from litellm import verbose_logger
class OllamaError(Exception): class OllamaError(Exception):
@ -195,6 +197,7 @@ class OllamaChatConfig:
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
model_response: litellm.ModelResponse,
api_base="http://localhost:11434", api_base="http://localhost:11434",
api_key: Optional[str] = None, api_key: Optional[str] = None,
model="llama2", model="llama2",
@ -202,7 +205,6 @@ def get_ollama_response(
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
model_response=None,
encoding=None, encoding=None,
): ):
if api_base.endswith("/api/chat"): if api_base.endswith("/api/chat"):
@ -295,7 +297,7 @@ def get_ollama_response(
response_json = response.json() response_json = response.json()
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["message"]["content"]) function_call = json.loads(response_json["message"]["content"])
message = litellm.Message( message = litellm.Message(
@ -311,22 +313,24 @@ def get_ollama_response(
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response.choices[0].message = message # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["message"][ model_response.choices[0].message.content = response_json["message"]["content"] # type: ignore
"content" model_response.created = int(time.time())
] model_response.model = "ollama/" + model
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
completion_tokens = response_json.get( completion_tokens = response_json.get(
"eval_count", litellm.token_counter(text=response_json["message"]["content"]) "eval_count", litellm.token_counter(text=response_json["message"]["content"])
) )
model_response["usage"] = litellm.Usage( setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
@ -379,8 +383,8 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
], ],
) )
model_response = first_chunk model_response = first_chunk
model_response["choices"][0]["delta"] = delta model_response.choices[0].delta = delta # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
yield model_response yield model_response
else: else:
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
@ -434,7 +438,9 @@ async def ollama_async_streaming(
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {
"name": function_call.get("name", function_call.get("function", None)), "name": function_call.get(
"name", function_call.get("function", None)
),
"arguments": json.dumps(function_call["arguments"]), "arguments": json.dumps(function_call["arguments"]),
}, },
"type": "function", "type": "function",
@ -442,8 +448,8 @@ async def ollama_async_streaming(
], ],
) )
model_response = first_chunk model_response = first_chunk
model_response["choices"][0]["delta"] = delta model_response.choices[0].delta = delta # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
yield model_response yield model_response
else: else:
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
@ -457,7 +463,7 @@ async def ollama_acompletion(
url, url,
api_key: Optional[str], api_key: Optional[str],
data, data,
model_response, model_response: litellm.ModelResponse,
encoding, encoding,
logging_obj, logging_obj,
function_name, function_name,
@ -492,7 +498,7 @@ async def ollama_acompletion(
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["message"]["content"]) function_call = json.loads(response_json["message"]["content"])
message = litellm.Message( message = litellm.Message(
@ -510,15 +516,17 @@ async def ollama_acompletion(
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response.choices[0].message = message # type: ignore
model_response["choices"][0]["finish_reason"] = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json[ model_response.choices[0].message.content = response_json[ # type: ignore
"message" "message"
]["content"] ][
"content"
]
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"] model_response.model = "ollama_chat/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=data["messages"])) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=data["messages"])) # type: ignore
completion_tokens = response_json.get( completion_tokens = response_json.get(
"eval_count", "eval_count",
@ -526,10 +534,14 @@ async def ollama_acompletion(
text=response_json["message"]["content"], count_response_tokens=True text=response_json["message"]["content"], count_response_tokens=True
), ),
) )
model_response["usage"] = litellm.Usage( setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response
except Exception as e: except Exception as e:

View file

@ -1,11 +1,14 @@
import os
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt import requests # type: ignore
from litellm.utils import EmbeddingResponse, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
class OobaboogaError(Exception): class OobaboogaError(Exception):
@ -99,17 +102,15 @@ def completion(
) )
else: else:
try: try:
model_response["choices"][0]["message"]["content"] = ( model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
completion_response["choices"][0]["message"]["content"]
)
except: except:
raise OobaboogaError( raise OobaboogaError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
status_code=response.status_code, status_code=response.status_code,
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=completion_response["usage"]["prompt_tokens"], prompt_tokens=completion_response["usage"]["prompt_tokens"],
completion_tokens=completion_response["usage"]["completion_tokens"], completion_tokens=completion_response["usage"]["completion_tokens"],
@ -122,10 +123,10 @@ def completion(
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
model_response: EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
): ):
@ -166,7 +167,7 @@ def embedding(
) )
# Process response data # Process response data
model_response["data"] = [ model_response.data = [
{ {
"embedding": completion_response["data"][0]["embedding"], "embedding": completion_response["data"][0]["embedding"],
"index": 0, "index": 0,
@ -176,8 +177,12 @@ def embedding(
num_tokens = len(completion_response["data"][0]["embedding"]) num_tokens = len(completion_response["data"][0]["embedding"])
# Adding metadata to response # Adding metadata to response
model_response.usage = Usage(prompt_tokens=num_tokens, total_tokens=num_tokens) setattr(
model_response["object"] = "list" model_response,
model_response["model"] = model "usage",
Usage(prompt_tokens=num_tokens, total_tokens=num_tokens),
)
model_response.object = "list"
model_response.model = model
return model_response return model_response

View file

@ -1,12 +1,14 @@
import types
import traceback
import copy import copy
import time import time
import traceback
import types
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx import httpx
import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
class PalmError(Exception): class PalmError(Exception):
@ -164,7 +166,7 @@ def completion(
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(index=idx + 1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response.choices = choices_list # type: ignore
except Exception as e: except Exception as e:
verbose_logger.error( verbose_logger.error(
"litellm.llms.palm.py::completion(): Exception occured - {}".format(str(e)) "litellm.llms.palm.py::completion(): Exception occured - {}".format(str(e))
@ -188,8 +190,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = "palm/" + model model_response.model = "palm/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,12 +1,16 @@
import os, types
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
import types
from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import custom_prompt, prompt_factory
class PetalsError(Exception): class PetalsError(Exception):
@ -151,8 +155,8 @@ def completion(
else: else:
try: try:
import torch import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM # type: ignore from petals import AutoDistributedModelForCausalLM # type: ignore
from transformers import AutoTokenizer
except: except:
raise Exception( raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals" "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
@ -189,15 +193,15 @@ def completion(
output_text = tokenizer.decode(outputs[0]) output_text = tokenizer.decode(outputs[0])
if len(output_text) > 0: if len(output_text) > 0:
model_response["choices"][0]["message"]["content"] = output_text model_response.choices[0].message.content = output_text # type: ignore
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -279,7 +279,7 @@ class PredibaseChatCompletion(BaseLLM):
message=f"'generated_text' is not a key response dictionary - {completion_response}", message=f"'generated_text' is not a key response dictionary - {completion_response}",
) )
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = self.output_parser( model_response.choices[0].message.content = self.output_parser( # type: ignore
completion_response["generated_text"] completion_response["generated_text"]
) )
## GETTING LOGPROBS + FINISH REASON ## GETTING LOGPROBS + FINISH REASON
@ -294,10 +294,10 @@ class PredibaseChatCompletion(BaseLLM):
for token in completion_response["details"]["tokens"]: for token in completion_response["details"]["tokens"]:
if token["logprob"] is not None: if token["logprob"] is not None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0][ setattr(
"message" model_response.choices[0].message, # type: ignore
]._logprob = ( "_logprob",
sum_logprob # [TODO] move this to using the actual logprobs sum_logprob, # [TODO] move this to using the actual logprobs
) )
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
if ( if (
@ -325,7 +325,7 @@ class PredibaseChatCompletion(BaseLLM):
message=message_obj, message=message_obj,
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response.choices.extend(choices_list)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = 0 prompt_tokens = 0
@ -351,8 +351,8 @@ class PredibaseChatCompletion(BaseLLM):
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -388,7 +388,7 @@ def process_response(
## Building RESPONSE OBJECT ## Building RESPONSE OBJECT
if len(result) > 1: if len(result) > 1:
model_response["choices"][0]["message"]["content"] = result model_response.choices[0].message.content = result # type :ignore
# Calculate usage # Calculate usage
prompt_tokens = len(encoding.encode(prompt, disallowed_special=())) prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
@ -398,7 +398,7 @@ def process_response(
disallowed_special=(), disallowed_special=(),
) )
) )
model_response["model"] = "replicate/" + model model_response.model = "replicate/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -498,7 +498,7 @@ def completion(
## Step1: Start Prediction: gets a prediction url ## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response ## Step2: Poll prediction url for response
## Step2: is handled with and without streaming ## Step2: is handled with and without streaming
model_response["created"] = int( model_response.created = int(
time.time() time.time()
) # for pricing this must remain right before calling api ) # for pricing this must remain right before calling api

View file

@ -1,16 +1,21 @@
import os, types, traceback
from enum import Enum
import json
import requests # type: ignore
import time
from typing import Callable, Optional, Any
import litellm
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys
from copy import deepcopy
import httpx # type: ignore
import io import io
from .prompt_templates.factory import prompt_factory, custom_prompt import json
import os
import sys
import time
import traceback
import types
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Optional
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, get_secret
from .prompt_templates.factory import custom_prompt, prompt_factory
class SagemakerError(Exception): class SagemakerError(Exception):
@ -377,7 +382,7 @@ def completion(
if completion_output.startswith(prompt) and "<s>" in prompt: if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1) completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output model_response.choices[0].message.content = completion_output # type: ignore
except: except:
raise SagemakerError( raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
@ -390,8 +395,8 @@ def completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -597,7 +602,7 @@ async def async_completion(
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]: if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
completion_output = completion_output.replace(data["inputs"], "", 1) completion_output = completion_output.replace(data["inputs"], "", 1)
model_response["choices"][0]["message"]["content"] = completion_output model_response.choices[0].message.content = completion_output # type: ignore
except: except:
raise SagemakerError( raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
@ -610,8 +615,8 @@ async def async_completion(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -741,16 +746,20 @@ def embedding(
{"object": "embedding", "index": idx, "embedding": embedding} {"object": "embedding", "index": idx, "embedding": embedding}
) )
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = output_data model_response.data = output_data
model_response["model"] = model model_response.model = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens += len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage( setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
),
) )
return model_response return model_response

View file

@ -3,233 +3,233 @@ Deprecated. We now do together ai calls via the openai client.
Reference: https://docs.together.ai/docs/openai-api-compatibility Reference: https://docs.together.ai/docs/openai-api-compatibility
""" """
import os, types # import os, types
import json # import json
from enum import Enum # from enum import Enum
import requests # type: ignore # import requests # type: ignore
import time # import time
from typing import Callable, Optional # from typing import Callable, Optional
import litellm # import litellm
import httpx # type: ignore # import httpx # type: ignore
from litellm.utils import ModelResponse, Usage # from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt # from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception): # class TogetherAIError(Exception):
def __init__(self, status_code, message): # def __init__(self, status_code, message):
self.status_code = status_code # self.status_code = status_code
self.message = message # self.message = message
self.request = httpx.Request( # self.request = httpx.Request(
method="POST", url="https://api.together.xyz/inference" # method="POST", url="https://api.together.xyz/inference"
) # )
self.response = httpx.Response(status_code=status_code, request=self.request) # self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( # super().__init__(
self.message # self.message
) # Call the base class constructor with the parameters it needs # ) # Call the base class constructor with the parameters it needs
class TogetherAIConfig: # class TogetherAIConfig:
""" # """
Reference: https://docs.together.ai/reference/inference # Reference: https://docs.together.ai/reference/inference
The class `TogetherAIConfig` provides configuration for the TogetherAI's API interface. Here are the parameters: # The class `TogetherAIConfig` provides configuration for the TogetherAI's API interface. Here are the parameters:
- `max_tokens` (int32, required): The maximum number of tokens to generate. # - `max_tokens` (int32, required): The maximum number of tokens to generate.
- `stop` (string, optional): A string sequence that will truncate (stop) the inference text output. For example, "\n\n" will stop generation as soon as the model generates two newlines. # - `stop` (string, optional): A string sequence that will truncate (stop) the inference text output. For example, "\n\n" will stop generation as soon as the model generates two newlines.
- `temperature` (float, optional): A decimal number that determines the degree of randomness in the response. A value of 1 will always yield the same output. A temperature less than 1 favors more correctness and is appropriate for question answering or summarization. A value greater than 1 introduces more randomness in the output. # - `temperature` (float, optional): A decimal number that determines the degree of randomness in the response. A value of 1 will always yield the same output. A temperature less than 1 favors more correctness and is appropriate for question answering or summarization. A value greater than 1 introduces more randomness in the output.
- `top_p` (float, optional): The `top_p` (nucleus) parameter is used to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out. This technique helps to maintain diversity and generate more fluent and natural-sounding text. # - `top_p` (float, optional): The `top_p` (nucleus) parameter is used to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out. This technique helps to maintain diversity and generate more fluent and natural-sounding text.
- `top_k` (int32, optional): The `top_k` parameter is used to limit the number of choices for the next predicted word or token. It specifies the maximum number of tokens to consider at each step, based on their probability of occurrence. This technique helps to speed up the generation process and can improve the quality of the generated text by focusing on the most likely options. # - `top_k` (int32, optional): The `top_k` parameter is used to limit the number of choices for the next predicted word or token. It specifies the maximum number of tokens to consider at each step, based on their probability of occurrence. This technique helps to speed up the generation process and can improve the quality of the generated text by focusing on the most likely options.
- `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition. # - `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition.
- `logprobs` (int32, optional): This parameter is not described in the prompt. # - `logprobs` (int32, optional): This parameter is not described in the prompt.
""" # """
max_tokens: Optional[int] = None # max_tokens: Optional[int] = None
stop: Optional[str] = None # stop: Optional[str] = None
temperature: Optional[int] = None # temperature: Optional[int] = None
top_p: Optional[float] = None # top_p: Optional[float] = None
top_k: Optional[int] = None # top_k: Optional[int] = None
repetition_penalty: Optional[float] = None # repetition_penalty: Optional[float] = None
logprobs: Optional[int] = None # logprobs: Optional[int] = None
def __init__( # def __init__(
self, # self,
max_tokens: Optional[int] = None, # max_tokens: Optional[int] = None,
stop: Optional[str] = None, # stop: Optional[str] = None,
temperature: Optional[int] = None, # temperature: Optional[int] = None,
top_p: Optional[float] = None, # top_p: Optional[float] = None,
top_k: Optional[int] = None, # top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None, # repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None, # logprobs: Optional[int] = None,
) -> None: # ) -> None:
locals_ = locals() # locals_ = locals()
for key, value in locals_.items(): # for key, value in locals_.items():
if key != "self" and value is not None: # if key != "self" and value is not None:
setattr(self.__class__, key, value) # setattr(self.__class__, key, value)
@classmethod # @classmethod
def get_config(cls): # def get_config(cls):
return { # return {
k: v # k: v
for k, v in cls.__dict__.items() # for k, v in cls.__dict__.items()
if not k.startswith("__") # if not k.startswith("__")
and not isinstance( # and not isinstance(
v, # v,
( # (
types.FunctionType, # types.FunctionType,
types.BuiltinFunctionType, # types.BuiltinFunctionType,
classmethod, # classmethod,
staticmethod, # staticmethod,
), # ),
) # )
and v is not None # and v is not None
} # }
def validate_environment(api_key): # def validate_environment(api_key):
if api_key is None: # if api_key is None:
raise ValueError( # raise ValueError(
"Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params" # "Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params"
) # )
headers = { # headers = {
"accept": "application/json", # "accept": "application/json",
"content-type": "application/json", # "content-type": "application/json",
"Authorization": "Bearer " + api_key, # "Authorization": "Bearer " + api_key,
} # }
return headers # return headers
def completion( # def completion(
model: str, # model: str,
messages: list, # messages: list,
api_base: str, # api_base: str,
model_response: ModelResponse, # model_response: ModelResponse,
print_verbose: Callable, # print_verbose: Callable,
encoding, # encoding,
api_key, # api_key,
logging_obj, # logging_obj,
custom_prompt_dict={}, # custom_prompt_dict={},
optional_params=None, # optional_params=None,
litellm_params=None, # litellm_params=None,
logger_fn=None, # logger_fn=None,
): # ):
headers = validate_environment(api_key) # headers = validate_environment(api_key)
## Load Config # ## Load Config
config = litellm.TogetherAIConfig.get_config() # config = litellm.TogetherAIConfig.get_config()
for k, v in config.items(): # for k, v in config.items():
if ( # if (
k not in optional_params # k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in # ): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v # optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}") # print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
if model in custom_prompt_dict: # if model in custom_prompt_dict:
# check if the model has a registered custom prompt # # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] # model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( # prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}), # role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), # initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), # final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""), # bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""), # eos_token=model_prompt_details.get("eos_token", ""),
messages=messages, # messages=messages,
) # )
else: # else:
prompt = prompt_factory( # prompt = prompt_factory(
model=model, # model=model,
messages=messages, # messages=messages,
api_key=api_key, # api_key=api_key,
custom_llm_provider="together_ai", # custom_llm_provider="together_ai",
) # api key required to query together ai model list # ) # api key required to query together ai model list
data = { # data = {
"model": model, # "model": model,
"prompt": prompt, # "prompt": prompt,
"request_type": "language-model-inference", # "request_type": "language-model-inference",
**optional_params, # **optional_params,
} # }
## LOGGING # ## LOGGING
logging_obj.pre_call( # logging_obj.pre_call(
input=prompt, # input=prompt,
api_key=api_key, # api_key=api_key,
additional_args={ # additional_args={
"complete_input_dict": data, # "complete_input_dict": data,
"headers": headers, # "headers": headers,
"api_base": api_base, # "api_base": api_base,
}, # },
) # )
## COMPLETION CALL # ## COMPLETION CALL
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: # if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
response = requests.post( # response = requests.post(
api_base, # api_base,
headers=headers, # headers=headers,
data=json.dumps(data), # data=json.dumps(data),
stream=optional_params["stream_tokens"], # stream=optional_params["stream_tokens"],
) # )
return response.iter_lines() # return response.iter_lines()
else: # else:
response = requests.post(api_base, headers=headers, data=json.dumps(data)) # response = requests.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING # ## LOGGING
logging_obj.post_call( # logging_obj.post_call(
input=prompt, # input=prompt,
api_key=api_key, # api_key=api_key,
original_response=response.text, # original_response=response.text,
additional_args={"complete_input_dict": data}, # additional_args={"complete_input_dict": data},
) # )
print_verbose(f"raw model_response: {response.text}") # print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT # ## RESPONSE OBJECT
if response.status_code != 200: # if response.status_code != 200:
raise TogetherAIError( # raise TogetherAIError(
status_code=response.status_code, message=response.text # status_code=response.status_code, message=response.text
) # )
completion_response = response.json() # completion_response = response.json()
if "error" in completion_response: # if "error" in completion_response:
raise TogetherAIError( # raise TogetherAIError(
message=json.dumps(completion_response), # message=json.dumps(completion_response),
status_code=response.status_code, # status_code=response.status_code,
) # )
elif "error" in completion_response["output"]: # elif "error" in completion_response["output"]:
raise TogetherAIError( # raise TogetherAIError(
message=json.dumps(completion_response["output"]), # message=json.dumps(completion_response["output"]),
status_code=response.status_code, # status_code=response.status_code,
) # )
if len(completion_response["output"]["choices"][0]["text"]) >= 0: # if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response[ # model_response.choices[0].message.content = completion_response["output"][
"output" # "choices"
]["choices"][0]["text"] # ][0]["text"]
## CALCULATING USAGE # ## CALCULATING USAGE
print_verbose( # print_verbose(
f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}" # f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
) # )
prompt_tokens = len(encoding.encode(prompt)) # prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( # completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) # encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) # )
if "finish_reason" in completion_response["output"]["choices"][0]: # if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"][ # model_response.choices[0].finish_reason = completion_response["output"][
"choices" # "choices"
][0]["finish_reason"] # ][0]["finish_reason"]
model_response["created"] = int(time.time()) # model_response["created"] = int(time.time())
model_response["model"] = "together_ai/" + model # model_response["model"] = "together_ai/" + model
usage = Usage( # usage = Usage(
prompt_tokens=prompt_tokens, # prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, # completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, # total_tokens=prompt_tokens + completion_tokens,
) # )
setattr(model_response, "usage", usage) # setattr(model_response, "usage", usage)
return model_response # return model_response
def embedding(): # def embedding():
# logic for parsing in - calling - parsing out model embedding calls # # logic for parsing in - calling - parsing out model embedding calls
pass # pass
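With the module fully commented out, the docstring note kept at the top of the file ("We now do together ai calls via the openai client") describes the replacement path. An illustrative usage sketch, not code from this commit; the model name is a placeholder and a Together AI API key is assumed to be configured in the environment:

import litellm

# Together AI now routes through litellm's OpenAI-compatible client path, so a
# plain completion call with the "together_ai/" prefix is all that is needed.
response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model name
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)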

View file

@ -852,16 +852,14 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
if isinstance(completion_response, litellm.Message): if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response model_response.choices[0].message = completion_response # type: ignore
elif len(str(completion_response)) > 0: elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str( model_response.choices[0].message.content = str(completion_response) # type: ignore
completion_response model_response.created = int(time.time())
) model_response.model = model
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
response_obj.candidates[0].finish_reason.name response_obj.candidates[0].finish_reason.name
) )
usage = Usage( usage = Usage(
@ -912,7 +910,7 @@ async def async_completion(
request_str: str, request_str: str,
print_verbose: Callable, print_verbose: Callable,
logging_obj, logging_obj,
encoding=None, encoding,
client_options=None, client_options=None,
instances=None, instances=None,
vertex_project=None, vertex_project=None,
@ -1088,16 +1086,16 @@ async def async_completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
if isinstance(completion_response, litellm.Message): if isinstance(completion_response, litellm.Message):
model_response["choices"][0]["message"] = completion_response model_response.choices[0].message = completion_response # type: ignore
elif len(str(completion_response)) > 0: elif len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str( model_response.choices[0].message.content = str( # type: ignore
completion_response completion_response
) )
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
response_obj.candidates[0].finish_reason.name response_obj.candidates[0].finish_reason.name
) )
usage = Usage( usage = Usage(
@ -1377,16 +1375,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
def embedding( def embedding(
model: str, model: str,
input: Union[list, str], input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
optional_params=None,
encoding=None, encoding=None,
vertex_project=None, vertex_project=None,
vertex_location=None, vertex_location=None,
vertex_credentials=None, vertex_credentials=None,
aembedding=False, aembedding=False,
print_verbose=None,
): ):
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
try: try:
@ -1484,15 +1482,15 @@ def embedding(
"embedding": embedding.values, "embedding": embedding.values,
} }
) )
input_tokens += embedding.statistics.token_count input_tokens += embedding.statistics.token_count # type: ignore
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = embedding_response model_response.data = embedding_response
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response
@ -1500,8 +1498,8 @@ def embedding(
async def async_embedding( async def async_embedding(
model: str, model: str,
input: Union[list, str], input: Union[list, str],
model_response: litellm.EmbeddingResponse,
logging_obj=None, logging_obj=None,
model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
client=None, client=None,
@ -1541,11 +1539,11 @@ async def async_embedding(
) )
input_tokens += embedding.statistics.token_count input_tokens += embedding.statistics.token_count
model_response["object"] = "list" model_response.object = "list"
model_response["data"] = embedding_response model_response.data = embedding_response
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -367,8 +367,8 @@ async def async_completion(
prompt_tokens = message.usage.input_tokens prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -1,11 +1,15 @@
import os
import json import json
import os
import time # type: ignore
from enum import Enum from enum import Enum
from typing import Any, Callable
import httpx
import requests # type: ignore import requests # type: ignore
import time, httpx # type: ignore
from typing import Callable, Any
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import custom_prompt, prompt_factory
llm = None llm = None
@ -91,14 +95,14 @@ def completion(
) )
print_verbose(f"raw model_response: {outputs}") print_verbose(f"raw model_response: {outputs}")
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text model_response.choices[0].message.content = outputs[0].outputs[0].text # type: ignore
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len(outputs[0].prompt_token_ids) prompt_tokens = len(outputs[0].prompt_token_ids)
completion_tokens = len(outputs[0].outputs[0].token_ids) completion_tokens = len(outputs[0].outputs[0].token_ids)
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -173,14 +177,14 @@ def batch_completions(
for output in outputs: for output in outputs:
model_response = ModelResponse() model_response = ModelResponse()
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = output.outputs[0].text model_response.choices[0].message.content = output.outputs[0].text # type: ignore
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len(output.prompt_token_ids) prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids) completion_tokens = len(output.outputs[0].token_ids)
model_response["created"] = int(time.time()) model_response.created = int(time.time())
model_response["model"] = model model_response.model = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -591,9 +591,9 @@ class IBMWatsonXAI(BaseLLM):
self, self,
model: str, model: str,
input: Union[list, str], input: Union[list, str],
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None, logging_obj=None,
model_response=None,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
print_verbose=None, print_verbose=None,
@ -610,7 +610,7 @@ class IBMWatsonXAI(BaseLLM):
if k not in optional_params: if k not in optional_params:
optional_params[k] = v optional_params[k] = v
model_response["model"] = model model_response.model = model
# Load auth variables from environment variables # Load auth variables from environment variables
if isinstance(input, str): if isinstance(input, str):