forked from phoenix/litellm-mirror
working write to cache with streaming
commit 8af6d967eb (parent 2ff82873e6)
4 changed files with 84 additions and 46 deletions
@@ -1,3 +1,4 @@
+import litellm
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
@@ -29,8 +30,9 @@ class InMemoryCache():
         self.cache_dict = {}
 
     def set_cache(self, key, value):
-        #print("in set cache for inmem")
+        print("in set cache for inmem")
         self.cache_dict[key] = value
+        print(self.cache_dict)
 
     def get_cache(self, key):
         #print("in get cache for inmem")
@@ -46,6 +48,8 @@ class Cache():
             self.cache = RedisCache(type, host, port, password)
         if type == "local":
             self.cache = InMemoryCache()
+        litellm.input_callback.append("cache")
+        litellm.success_callback.append("cache")
 
     def get_cache_key(self, *args, **kwargs):
         prompt = get_prompt(*args, **kwargs)
@@ -71,8 +75,11 @@ class Cache():
 
     def add_cache(self, result, *args, **kwargs):
        try:
+            # print("adding to cache", result)
            cache_key = self.get_cache_key(*args, **kwargs)
+            # print(cache_key)
            if cache_key is not None:
+                # print("adding to cache", cache_key, result)
                self.cache.set_cache(cache_key, result)
        except:
            pass
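The net effect of this first file's changes (the caching module): constructing a Cache() now registers a "cache" entry in both litellm.input_callback and litellm.success_callback, so litellm can consult the cache before a call and write to it afterwards. A minimal usage sketch, assuming the import path litellm.caching and a local in-memory cache by default; neither detail is shown in this diff:

import litellm
from litellm import completion
from litellm.caching import Cache  # assumed import path

# Registers the "cache" callbacks for input and success events (per the hunk above).
litellm.cache = Cache(type="local")

messages = [{"role": "user", "content": "hello gm who are u"}]

# First call hits the provider; the registered callbacks let litellm store the
# result under a key derived from the prompt via get_cache_key().
response1 = completion(model="gpt-3.5-turbo", messages=messages)

# A repeat call with the same prompt can then be answered from the in-memory cache.
response2 = completion(model="gpt-3.5-turbo", messages=messages)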
@@ -196,7 +196,7 @@ def completion(
                 engine=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
@@ -254,7 +254,7 @@ def completion(
                 model=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
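Both streaming call sites in completion() now hand the request's Logging object to the stream wrapper. That is the plumbing the rest of the commit relies on: the wrapper can call logging_obj.post_call() for every chunk, which in turn fires the new "cache" callback. A rough sketch of the constructor this implies; only the logging_obj keyword is visible in this diff, the rest of the parameter list is an assumption:

class CustomStreamWrapper:
    # Sketch only: parameters other than logging_obj are assumed, not shown in the diff.
    def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None):
        self.completion_stream = completion_stream
        self.model = model
        self.custom_llm_provider = custom_llm_provider
        self.logging_obj = logging_obj  # used in __next__ to log each streamed chunk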
@@ -153,15 +153,19 @@ def test_embedding_caching():
 
 
 # # test caching with streaming
-# messages = [{"role": "user", "content": "draft a 2 pg legal document on applying to litellm"}]
+# messages = [{"role": "user", "content": "hello gm who are u"}]
 # def test_caching_v2_stream():
 #     try:
 #         litellm.cache = Cache()
+#         # litellm.token="ishaan@berri.ai"
 #         response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 #         for chunk in response1:
-#             print(chunk)
-#             response1_id = chunk['id']
+#             #
+#             pass
+#             # print("chunk")
+#         pass
+#         # response1_id = chunk['id']
 
 # #         response2 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 # #         for chunk in response2:
 # #             #print(chunk)
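For reference, a runnable version of the still-commented-out streaming test above could look like the following. The accumulation loop, the cache reset, and the pytest failure handling are assumptions about intent, not part of the commit:

import pytest
import litellm
from litellm import completion
from litellm.caching import Cache  # assumed import path

def test_caching_v2_stream():
    try:
        litellm.cache = Cache()
        messages = [{"role": "user", "content": "hello gm who are u"}]
        response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
        streamed_text = ""
        for chunk in response1:
            # each chunk is shaped like {"choices": [{"delta": {"role": ..., "content": ...}}]}
            streamed_text += chunk["choices"][0]["delta"]["content"]
        print("streamed response:", streamed_text)
        litellm.cache = None  # reset so later tests are not affected
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")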
litellm/utils.py (105 lines changed)
@@ -70,7 +70,7 @@ last_fetched_at_keys = None
 
 
 class Message(OpenAIObject):
-    def __init__(self, content="default", role="assistant", **params):
+    def __init__(self, content=" ", role="assistant", **params):
         super(Message, self).__init__(**params)
         self.content = content
         self.role = role
@@ -285,6 +285,26 @@ class Logging:
                         call_type = self.call_type,
                         stream = self.stream
                     )
+                if callback == "cache":
+                    try:
+                        #print("in cache callback2", self.stream)
+                        #print(original_response)
+                        #print(self.model_call_details)
+
+                        if litellm.cache != None:
+                            if self.litellm_params["stream_response"] == None:
+                                self.litellm_params["stream_response"] = ModelResponse()
+                            else:
+                                #self.litellm_call_id["stream_response"]["id"] = self.litellm_params["litellm_call_id"]
+                                self.litellm_params["stream_response"]["choices"][0]["message"]["content"] += original_response
+                            #print("cache is not none")
+                            # convert original_response to format of Model Object
+                            # Set the model
+                            litellm.cache.add_cache(self.litellm_params["stream_response"], **self.model_call_details)
+                            #print(self.litellm_params["stream_response"])
+                    except Exception as e:
+                        print("got exception")
+                        print(e)
             except:
                 print_verbose(
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while post-call logging with integrations {traceback.format_exc()}"
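This is the core of the commit. When post_call() runs for a streamed chunk, the "cache" branch keeps a running ModelResponse in litellm_params["stream_response"]: the first chunk creates it, every later chunk's text is concatenated onto choices[0].message.content, and after each chunk the partial response is re-written to the cache via add_cache(). The Message default changing from "default" to " " in the hunk above is presumably so the accumulated text does not begin with the literal word "default". A standalone, distilled sketch of that accumulation, using a plain dict in place of litellm's ModelResponse:

def accumulate_stream_chunk(litellm_params, chunk_text):
    # Mirrors the new "cache" branch: the first chunk only creates the running
    # response object (its text is not appended, exactly as in the diff); each
    # later chunk is concatenated onto choices[0].message.content.
    if litellm_params["stream_response"] is None:
        litellm_params["stream_response"] = {
            "choices": [{"message": {"role": "assistant", "content": " "}}]
        }
    else:
        litellm_params["stream_response"]["choices"][0]["message"]["content"] += chunk_text
    # The real code then calls litellm.cache.add_cache(stream_response, **model_call_details)
    # so the cache entry is refreshed after every chunk.
    return litellm_params["stream_response"]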
@@ -624,6 +644,7 @@ def get_litellm_params(
         "custom_api_base": custom_api_base,
         "litellm_call_id": litellm_call_id,
         "model_alias_map": model_alias_map,
+        "stream_response": None
     }
 
     return litellm_params
@@ -1576,7 +1597,10 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["text"]
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
+    def handle_openai_chat_completion_chunk(self, chunk):
+        return chunk["choices"][0]["delta"]["content"]
+
     def handle_baseten_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
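The new handle_openai_chat_completion_chunk() helper simply pulls the delta text out of an OpenAI chat streaming payload. A representative chunk it indexes into, trimmed to the relevant fields; note that the final chunk of a stream carries an empty delta, which this direct indexing does not guard against:

chunk = {
    "id": "chatcmpl-abc123",          # illustrative id, not from the diff
    "object": "chat.completion.chunk",
    "choices": [
        {"index": 0, "delta": {"role": "assistant", "content": "Hello"}, "finish_reason": None}
    ],
}
text = chunk["choices"][0]["delta"]["content"]  # -> "Hello"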
@@ -1593,44 +1617,47 @@ class CustomStreamWrapper:
             return ""
 
     def __next__(self):
-        completion_obj = {"role": "assistant", "content": ""}
-        if self.model in litellm.anthropic_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_anthropic_chunk(chunk)
-        elif self.model == "replicate":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk
-        elif (
-            self.custom_llm_provider and self.custom_llm_provider == "together_ai"
-        ) or ("togethercomputer" in self.model):
-            chunk = next(self.completion_stream)
-            text_data = self.handle_together_ai_chunk(chunk)
-            if text_data == "":
-                return self.__next__()
-            completion_obj["content"] = text_data
-        elif self.model in litellm.cohere_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk.text
-        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_baseten_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_ai21_chunk(chunk)
-        elif self.model in litellm.open_ai_text_completion_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
-        else: # openai chat/azure models
-            chunk = next(self.completion_stream)
-            return chunk
-
-        # LOGGING
-        self.logging_obj.post_call(completion_obj["content"])
-        # return this for all models
-        return {"choices": [{"delta": completion_obj}]}
+        try:
+            completion_obj = {"role": "assistant", "content": ""}
+            if self.model in litellm.anthropic_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_anthropic_chunk(chunk)
+            elif self.model == "replicate":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk
+            elif (
+                self.custom_llm_provider and self.custom_llm_provider == "together_ai"
+            ) or ("togethercomputer" in self.model):
+                chunk = next(self.completion_stream)
+                text_data = self.handle_together_ai_chunk(chunk)
+                if text_data == "":
+                    return self.__next__()
+                completion_obj["content"] = text_data
+            elif self.model in litellm.cohere_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_huggingface_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_baseten_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_ai21_chunk(chunk)
+            elif self.model in litellm.open_ai_text_completion_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
+            else: # openai chat/azure models
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
+
+            # LOGGING
+            self.logging_obj.post_call(completion_obj["content"])
+            # return this for all models
+            return {"choices": [{"delta": completion_obj}]}
+        except:
+            raise StopIteration
 
 
 ########## Reading Config File ############################
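With the __next__ rewrite, every provider branch, including plain OpenAI chat/Azure (which previously returned the raw chunk), funnels through the same path: extract the text, log it via logging_obj.post_call() so the cache accumulation above runs, and yield {"choices": [{"delta": ...}]}. Any exception, including exhaustion of the underlying stream, is turned into StopIteration by the bare except, which also means genuine errors end the stream silently. A consumption sketch, with raw_stream and logging standing in for the objects built in completion():

wrapper = CustomStreamWrapper(raw_stream, model="gpt-3.5-turbo", logging_obj=logging)

collected = ""
for piece in wrapper:
    # every provider now yields the same normalized shape
    collected += piece["choices"][0]["delta"]["content"]

# When raw_stream is exhausted, next() raises inside the try block and the
# wrapper re-raises StopIteration, so the for-loop above ends cleanly.
print(collected)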