refactor: fixing linting issues

Krrish Dholakia 2023-11-11 18:52:28 -08:00
parent ae35c13015
commit 45b6f8b853
25 changed files with 223 additions and 133 deletions
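
The provider hunks below all apply the same refactor: instead of mutating model_response.usage field by field, each completion() builds a litellm Usage object and assigns it once. A rough standalone sketch of that pattern (attach_usage is a hypothetical helper name, not part of this commit):

    from litellm.utils import ModelResponse, Usage  # imports added throughout the diff below

    def attach_usage(model_response: ModelResponse, prompt_tokens: int, completion_tokens: int) -> ModelResponse:
        # hypothetical helper mirroring the per-provider change
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        model_response.usage = usage
        return model_response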

View file

@@ -76,8 +76,8 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
         self.llm_provider = llm_provider
         super().__init__(
             message=self.message,
-            model=self.model,
-            llm_provider=self.llm_provider,
+            model=self.model, # type: ignore
+            llm_provider=self.llm_provider, # type: ignore
             response=response
         ) # Call the base class constructor with the parameters it needs
@@ -101,7 +101,7 @@ class APIError(APIError): # type: ignore
         self.model = model
         super().__init__(
             self.message,
-            request=request
+            request=request # type: ignore
         )
 # raised if an invalid request (not get, delete, put, post) is made

View file

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, Usage
 import httpx
 class AlephAlphaError(Exception):
@@ -265,9 +265,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 import httpx
@@ -167,9 +167,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -7,11 +7,17 @@ from litellm import OpenAIConfig
 import httpx
 class AzureOpenAIError(Exception):
-    def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
+    def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
         self.status_code = status_code
         self.message = message
+        if request:
             self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        if response:
             self.response = response
+        else:
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
@@ -136,7 +142,7 @@ class AzureChatCompletion(BaseLLM):
                 headers=headers,
             )
             if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
@@ -172,7 +178,7 @@ class AzureChatCompletion(BaseLLM):
             method="POST"
         ) as response:
             if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
             completion_stream = response.iter_lines()
             streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
@@ -194,7 +200,7 @@ class AzureChatCompletion(BaseLLM):
             method="POST"
         ) as response:
             if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
             streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
             async for transformed_chunk in streamwrapper:

View file

@@ -1,9 +1,10 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
 import httpx, certifi, ssl
+from typing import Optional
 class BaseLLM:
-    _client_session = None
+    _client_session: Optional[httpx.Client] = None
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 class BasetenError(Exception):
     def __init__(self, status_code, message):
@@ -136,9 +136,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse, get_secret
+from litellm.utils import ModelResponse, get_secret, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 import httpx
@@ -424,9 +424,12 @@ def completion(
         model_response["created"] = time.time()
         model_response["model"] = model
-        model_response.usage.completion_tokens = completion_tokens
-        model_response.usage.prompt_tokens = prompt_tokens
-        model_response.usage.total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens = prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
         return model_response
     except BedrockError as e:
         exception_mapping_worked = True
@@ -497,6 +500,11 @@ def embedding(
         "total_tokens": input_tokens,
     }
+    usage = Usage(
+        prompt_tokens=input_tokens,
+        completion_tokens=0,
+        total_tokens=input_tokens + 0
+    )
+    model_response.usage = usage
     return model_response

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time, traceback
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, Usage
 import litellm
 import httpx
@@ -186,9 +186,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding(

View file

@@ -6,7 +6,7 @@ import httpx, requests
 import time
 import litellm
 from typing import Callable, Dict, List, Any
-from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -381,9 +381,12 @@ def completion(
         model_response["created"] = time.time()
         model_response["model"] = model
-        model_response.usage.completion_tokens = completion_tokens
-        model_response.usage.prompt_tokens = prompt_tokens
-        model_response.usage.total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
         model_response._hidden_params["original_response"] = completion_response
         return model_response
     except HuggingfaceError as e:

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time, traceback
 from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, Usage
 import litellm
 class MaritalkError(Exception):
@@ -145,9 +145,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding(

View file

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 class NLPCloudError(Exception):
     def __init__(self, status_code, message):
@@ -171,9 +171,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 class OobaboogaError(Exception):
@@ -111,9 +111,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -2,16 +2,22 @@ from typing import Optional, Union
 import types
 import httpx
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
 from typing import Callable, Optional
 import aiohttp
 class OpenAIError(Exception):
-    def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
+    def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
         self.status_code = status_code
         self.message = message
+        if request:
             self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        if response:
             self.response = response
+        else:
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
@@ -264,13 +270,13 @@ class OpenAIChatCompletion(BaseLLM):
         model: str
     ):
         with self._client_session.stream(
-            url=f"{api_base}",
+            url=f"{api_base}", # type: ignore
            json=data,
            headers=headers,
            method="POST"
        ) as response:
            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
            completion_stream = response.iter_lines()
            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
@@ -292,7 +298,7 @@ class OpenAIChatCompletion(BaseLLM):
            method="POST"
        ) as response:
            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai",logging_obj=logging_obj)
            async for transformed_chunk in streamwrapper:
@@ -383,7 +389,7 @@ class OpenAITextCompletion(BaseLLM):
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
-                raise ValueError(message="Error in response object format")
+                raise ValueError("Error in response object format")
            choice_list=[]
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(content=choice["text"], role="assistant")
@@ -406,11 +412,11 @@ class OpenAITextCompletion(BaseLLM):
            raise e
    def completion(self,
-        model: Optional[str]=None,
-        messages: Optional[list]=None,
-        model_response: Optional[ModelResponse]=None,
+        model_response: ModelResponse,
+        api_key: str,
+        model: str,
+        messages: list,
        print_verbose: Optional[Callable]=None,
-        api_key: Optional[str]=None,
        api_base: Optional[str]=None,
        logging_obj=None,
        acompletion: bool = False,
@@ -449,7 +455,7 @@ class OpenAITextCompletion(BaseLLM):
                if optional_params.get("stream", False):
                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
                else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
            else:
@@ -459,7 +465,7 @@ class OpenAITextCompletion(BaseLLM):
                    headers=headers,
                )
                if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text, request=self._client_session.request, response=response)
+                    raise OpenAIError(status_code=response.status_code, message=response.text)
                ## LOGGING
                logging_obj.post_call(
@@ -521,7 +527,7 @@ class OpenAITextCompletion(BaseLLM):
            method="POST"
        ) as response:
            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text)
            streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
            for transformed_chunk in streamwrapper:
@@ -542,7 +548,7 @@ class OpenAITextCompletion(BaseLLM):
            method="POST"
        ) as response:
            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text)
            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
            async for transformed_chunk in streamwrapper:
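
The AzureOpenAIError and OpenAIError constructors above now default their request/response parameters, so call sites can raise with only a status code and message. A minimal self-contained sketch of that defaulting pattern (ProviderError is a hypothetical stand-in, not litellm's class):

    import httpx
    from typing import Optional

    class ProviderError(Exception):
        def __init__(self, status_code: int, message: str,
                     request: Optional[httpx.Request] = None,
                     response: Optional[httpx.Response] = None):
            self.status_code = status_code
            self.message = message
            # fall back to synthetic httpx objects when the caller omits them
            self.request = request or httpx.Request(method="POST", url="https://api.openai.com/v1")
            self.response = response or httpx.Response(status_code=status_code, request=self.request)
            super().__init__(self.message)

    # callers no longer need a live request/response pair:
    raise ProviderError(status_code=500, message="upstream failure")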

View file

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, get_secret, Choices, Message
+from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
 import litellm
 import sys
@@ -157,9 +157,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = "palm/" + model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 class PetalsError(Exception):
@@ -176,9 +176,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -3,7 +3,7 @@ import json
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 import litellm
 import httpx
@@ -261,9 +261,12 @@ def completion(
     prompt_tokens = len(encoding.encode(prompt))
     completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", "")))
     model_response["model"] = "replicate/" + model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response

View file

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse, get_secret
+from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
@@ -172,9 +172,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -5,7 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 class TogetherAIError(Exception):
@@ -182,9 +182,12 @@ def completion(
     model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"]
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def embedding():

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 import litellm
 class VertexAIError(Exception):
@@ -150,10 +150,12 @@ def completion(
     completion_tokens = len(
         encoding.encode(model_response["choices"][0]["message"].get("content", ""))
     )
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Any
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 llm = None
 class VLLMError(Exception):
@@ -90,9 +90,12 @@ def completion(
     model_response["created"] = time.time()
     model_response["model"] = model
-    model_response.usage.completion_tokens = completion_tokens
-    model_response.usage.prompt_tokens = prompt_tokens
-    model_response.usage.total_tokens = prompt_tokens + completion_tokens
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
     return model_response
 def batch_completions(
@@ -170,9 +173,12 @@ def batch_completions(
         model_response["created"] = time.time()
         model_response["model"] = model
-        model_response.usage.completion_tokens = completion_tokens
-        model_response.usage.prompt_tokens = prompt_tokens
-        model_response.usage.total_tokens = prompt_tokens + completion_tokens
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
         final_outputs.append(model_response)
     return final_outputs

View file

@@ -12,6 +12,7 @@ from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
+import httpx
 import litellm
 from litellm import ( # type: ignore
     client,
@@ -838,14 +839,14 @@ def completion(
            )
            ## COMPLETION CALL
            openai.api_key = api_key # set key for deep infra
+            openai.base_url = api_base # use the deepinfra api base
            try:
-                response = openai.ChatCompletion.create(
-                    model=model,
-                    messages=messages,
-                    api_base=api_base, # use the deepinfra api base
-                    api_type="openai",
-                    api_version=api_version, # default None
-                    **optional_params,
+                response = openai.chat.completions.create(
+                    model=model, # type: ignore
+                    messages=messages, # type: ignore
+                    api_type="openai", # type: ignore
+                    api_version=api_version, # type: ignore
+                    **optional_params, # type: ignore
                )
            except Exception as e:
                ## LOGGING - log the original exception returned
@@ -932,7 +933,7 @@ def completion(
    elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
        openai.api_type = "openai"
        # not sure if this will work after someone first uses another API
-        openai.api_base = (
+        openai.base_url = (
            litellm.api_base
            if litellm.api_base is not None
            else "https://openrouter.ai/api/v1"
@@ -963,9 +964,9 @@ def completion(
        logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
        ## COMPLETION CALL
        if headers:
-            response = openai.ChatCompletion.create(
-                headers=headers,
-                **data,
+            response = openai.chat.completions.create(
+                headers=headers, # type: ignore
+                **data, # type: ignore
            )
        else:
            openrouter_site_url = get_secret("OR_SITE_URL")
@@ -976,11 +977,11 @@ def completion(
            # if openrouter_app_name is None, set it to liteLLM
            if openrouter_app_name is None:
                openrouter_app_name = "liteLLM"
-            response = openai.ChatCompletion.create(
-                headers={
-                    "HTTP-Referer": openrouter_site_url, # To identify your site
-                    "X-Title": openrouter_app_name, # To identify your app
-                },
+            response = openai.chat.completions.create( # type: ignore
+                extra_headers=httpx.Headers({ # type: ignore
+                    "HTTP-Referer": openrouter_site_url, # type: ignore
+                    "X-Title": openrouter_app_name, # type: ignore
+                }), # type: ignore
                **data,
            )
            ## LOGGING
@@ -1961,7 +1962,7 @@ def text_completion(
        futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)]
        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            responses[i] = future.result()
-        text_completion_response["choices"] = responses
+        text_completion_response.choices = responses
        return text_completion_response
    # else:
@@ -2012,10 +2013,10 @@ def moderation(input: str, api_key: Optional[str]=None):
        get_secret("OPENAI_API_KEY")
    )
    openai.api_key = api_key
-    openai.api_type = "open_ai"
+    openai.api_type = "open_ai" # type: ignore
    openai.api_version = None
-    openai.api_base = "https://api.openai.com/v1"
-    response = openai.Moderation.create(input)
+    openai.base_url = "https://api.openai.com/v1"
+    response = openai.moderations.create(input=input)
    return response
 ####### HELPER FUNCTIONS ################
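
The main.py hunks above track the openai python SDK's 1.x surface: a module-level base_url replaces api_base, and resource calls move to openai.chat.completions and openai.moderations. A small sketch under that assumption (the key and model name are placeholders):

    import openai

    openai.api_key = "sk-..."                      # placeholder
    openai.base_url = "https://api.openai.com/v1"  # 1.x name for the old api_base

    response = openai.chat.completions.create(     # replaces openai.ChatCompletion.create
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    moderation = openai.moderations.create(input="hello")  # replaces openai.Moderation.create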

View file

@@ -2,9 +2,9 @@
 # it makes async Completion requests with streaming
 import openai
-openai.api_base = "http://0.0.0.0:8000"
+openai.base_url = "http://0.0.0.0:8000"
 openai.api_key = "temp-key"
-print(openai.api_base)
+print(openai.base_url)
 async def test_async_completion():
     response = await openai.Completion.acreate(

View file

@ -1,8 +1,4 @@
try: from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIError
from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIError
except:
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
import os import os
import sys import sys
import traceback import traceback

View file

@@ -17,10 +17,7 @@ from concurrent import futures
 from inspect import iscoroutinefunction
 from functools import wraps
 from threading import Thread
-try:
-    from openai import Timeout
-except:
-    from openai.error import Timeout
+from openai import Timeout
 def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout):

View file

@@ -39,12 +39,8 @@ from .integrations.weights_biases import WeightsBiasesLogger
 from .integrations.custom_logger import CustomLogger
 from .integrations.langfuse import LangFuseLogger
 from .integrations.litedebugger import LiteDebugger
-try:
-    from openai import OpenAIError as OriginalError
-    from openai._models import BaseModel as OpenAIObject
-except:
-    from openai.error import OpenAIError as OriginalError
-    from openai.openai_object import OpenAIObject
+from openai import OpenAIError as OriginalError
+from openai._models import BaseModel as OpenAIObject
 from .exceptions import (
     AuthenticationError,
     BadRequestError,
@@ -354,6 +350,22 @@ class TextChoices(OpenAIObject):
         else:
             self.logprobs = logprobs
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
 class TextCompletionResponse(OpenAIObject):
     """
     {
@@ -399,6 +411,22 @@ class TextCompletionResponse(OpenAIObject):
         self._hidden_params = {} # used in case users want to access the original model response
         super(TextCompletionResponse, self).__init__(**params)
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
 ############################################################
 def print_verbose(print_statement):
     if litellm.set_verbose:
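
The dict-style helpers added to TextChoices and TextCompletionResponse keep subscript access working now that OpenAIObject is a pydantic BaseModel. A self-contained sketch of the same idea (TextChoicesSketch and its fields are hypothetical, not litellm's class):

    from typing import Optional
    from pydantic import BaseModel

    class TextChoicesSketch(BaseModel):
        text: Optional[str] = None
        finish_reason: Optional[str] = None

        def __contains__(self, key):
            return hasattr(self, key)

        def get(self, key, default=None):
            return getattr(self, key, default)

        def __getitem__(self, key):
            return getattr(self, key)

        def __setitem__(self, key, value):
            setattr(self, key, value)

    choice = TextChoicesSketch()
    choice["text"] = "hello"               # __setitem__ -> setattr
    assert "text" in choice                # __contains__ -> hasattr
    assert choice.get("missing", 0) == 0   # .get() with a default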