working write to cache with streaming

ishaan-jaff 2023-08-28 18:46:25 -07:00
parent 2ff82873e6
commit 8af6d967eb
4 changed files with 84 additions and 46 deletions

View file

@@ -1,3 +1,4 @@
+import litellm
 def get_prompt(*args, **kwargs):
     # make this safe checks, it should not throw any exceptions
     if len(args) > 1:
@@ -29,8 +30,9 @@ class InMemoryCache():
         self.cache_dict = {}

     def set_cache(self, key, value):
-        #print("in set cache for inmem")
+        print("in set cache for inmem")
         self.cache_dict[key] = value
+        print(self.cache_dict)

     def get_cache(self, key):
         #print("in get cache for inmem")
@@ -46,6 +48,8 @@ class Cache():
             self.cache = RedisCache(type, host, port, password)
         if type == "local":
             self.cache = InMemoryCache()
+        litellm.input_callback.append("cache")
+        litellm.success_callback.append("cache")

     def get_cache_key(self, *args, **kwargs):
         prompt = get_prompt(*args, **kwargs)
@@ -71,8 +75,11 @@ class Cache():

     def add_cache(self, result, *args, **kwargs):
         try:
+            # print("adding to cache", result)
             cache_key = self.get_cache_key(*args, **kwargs)
+            # print(cache_key)
             if cache_key is not None:
+                # print("adding to cache", cache_key, result)
                 self.cache.set_cache(cache_key, result)
         except:
             pass
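
For orientation, here is a minimal standalone sketch of the prompt-keyed cache pattern the file above implements: serialize the prompt into a key, store the finished response under it, and look it up on the next identical call. The names SimpleCache and make_key are illustrative assumptions, not part of this commit.

import hashlib
import json


class SimpleCache:
    """Illustrative in-memory cache, mirroring InMemoryCache above."""

    def __init__(self):
        self.cache_dict = {}

    def set_cache(self, key, value):
        self.cache_dict[key] = value

    def get_cache(self, key):
        return self.cache_dict.get(key)


def make_key(messages):
    # Hash the serialized prompt so identical prompts map to the same entry.
    return hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()


cache = SimpleCache()
messages = [{"role": "user", "content": "hello gm who are u"}]
cache.set_cache(make_key(messages), {"choices": [{"message": {"content": "hi!"}}]})
print(cache.get_cache(make_key(messages)))  # cache hit for the same prompt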

View file

@@ -196,7 +196,7 @@ def completion(
                 engine=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
@@ -254,7 +254,7 @@ def completion(
                 model=model, messages=messages, **optional_params
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model)
+                response = CustomStreamWrapper(response, model, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
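
The two hunks above thread the logging object into the stream wrapper so that every streamed chunk can be post-call logged (and, via the new "cache" callback, written to the cache). A minimal sketch of that wrapping idea, assuming hypothetical names rather than litellm's actual classes:

class StreamWrapperSketch:
    """Illustrative wrapper: yields chunks and notifies a logger for each one."""

    def __init__(self, completion_stream, logging_obj):
        self.completion_stream = iter(completion_stream)
        self.logging_obj = logging_obj

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)   # StopIteration ends the loop
        self.logging_obj.post_call(chunk)      # per-chunk hook (logging / cache write)
        return chunk


class PrintLogger:
    def post_call(self, content):
        print("post_call:", content)


for piece in StreamWrapperSketch(["Hel", "lo", "!"], PrintLogger()):
    pass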

View file

@@ -153,15 +153,19 @@ def test_embedding_caching():

 # # test caching with streaming
-# messages = [{"role": "user", "content": "draft a 2 pg legal document on applying to litellm"}]
+# messages = [{"role": "user", "content": "hello gm who are u"}]
 # def test_caching_v2_stream():
 #     try:
 #         litellm.cache = Cache()
+#         # litellm.token="ishaan@berri.ai"
 #         response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 #         for chunk in response1:
-#             print(chunk)
-#             response1_id = chunk['id']
+#             #
+#             pass
+#             # print("chunk")
+#         pass
+#         # response1_id = chunk['id']
 # #         response2 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
 # #         for chunk in response2:
 # #             #print(chunk)
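
For reference, a hedged sketch of what the commented-out streaming cache test above might look like if re-enabled. It assumes the litellm API used elsewhere in this commit (Cache from litellm.caching, completion with stream=True) and a configured API key, so treat it as illustrative rather than a test added by this commit.

import litellm
from litellm import completion
from litellm.caching import Cache


def test_caching_v2_stream():
    litellm.cache = Cache()
    messages = [{"role": "user", "content": "hello gm who are u"}]
    response1 = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
    for chunk in response1:
        # consuming the stream lets the success callback assemble and cache the response
        pass
    # an identical follow-up call should now be answered from the cache
    response2 = completion(model="gpt-3.5-turbo", messages=messages)
    litellm.cache = None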

View file

@@ -70,7 +70,7 @@ last_fetched_at_keys = None

 class Message(OpenAIObject):
-    def __init__(self, content="default", role="assistant", **params):
+    def __init__(self, content=" ", role="assistant", **params):
         super(Message, self).__init__(**params)
         self.content = content
         self.role = role
@@ -285,6 +285,26 @@ class Logging:
                            call_type = self.call_type,
                            stream = self.stream
                        )
+                    if callback == "cache":
+                        try:
+                            #print("in cache callback2", self.stream)
+                            #print(original_response)
+                            #print(self.model_call_details)
+                            if litellm.cache != None:
+                                if self.litellm_params["stream_response"] == None:
+                                    self.litellm_params["stream_response"] = ModelResponse()
+                                else:
+                                    #self.litellm_call_id["stream_response"]["id"] = self.litellm_params["litellm_call_id"]
+                                    self.litellm_params["stream_response"]["choices"][0]["message"]["content"] += original_response
+                                #print("cache is not none")
+                                # convert original_response to format of Model Object
+                                # Set the model
+                                litellm.cache.add_cache(self.litellm_params["stream_response"], **self.model_call_details)
+                                #print(self.litellm_params["stream_response"])
+                        except Exception as e:
+                            print("got exception")
+                            print(e)
                except:
                    print_verbose(
                        f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while post-call logging with integrations {traceback.format_exc()}"
@@ -624,6 +644,7 @@ def get_litellm_params(
         "custom_api_base": custom_api_base,
         "litellm_call_id": litellm_call_id,
         "model_alias_map": model_alias_map,
+        "stream_response": None
     }

     return litellm_params
@@ -1576,7 +1597,10 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["text"]
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")

+    def handle_openai_chat_completion_chunk(self, chunk):
+        return chunk["choices"][0]["delta"]["content"]
+
     def handle_baseten_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
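
The new handler above pulls the text delta out of an OpenAI chat streaming chunk. A small illustration of the chunk shape it assumes; the .get("") default for the final chunk (which may omit "content") is a defensive assumption, not part of the diff:

def extract_delta_content(chunk):
    # OpenAI chat streaming chunks carry text under choices[0]["delta"]["content"];
    # the final chunk may omit "content", so default to "" (assumed guard).
    return chunk["choices"][0]["delta"].get("content", "")


print(extract_delta_content({"choices": [{"delta": {"content": "Hel"}}]}))  # Hel
print(extract_delta_content({"choices": [{"delta": {}}]}))                  # (empty)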
@@ -1593,44 +1617,47 @@
         return ""

     def __next__(self):
-        completion_obj = {"role": "assistant", "content": ""}
-        if self.model in litellm.anthropic_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_anthropic_chunk(chunk)
-        elif self.model == "replicate":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk
-        elif (
-            self.custom_llm_provider and self.custom_llm_provider == "together_ai"
-        ) or ("togethercomputer" in self.model):
-            chunk = next(self.completion_stream)
-            text_data = self.handle_together_ai_chunk(chunk)
-            if text_data == "":
-                return self.__next__()
-            completion_obj["content"] = text_data
-        elif self.model in litellm.cohere_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = chunk.text
-        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_baseten_chunk(chunk)
-        elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_ai21_chunk(chunk)
-        elif self.model in litellm.open_ai_text_completion_models:
-            chunk = next(self.completion_stream)
-            completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
-        else: # openai chat/azure models
-            chunk = next(self.completion_stream)
-            return chunk
-        # LOGGING
-        self.logging_obj.post_call(completion_obj["content"])
-        # return this for all models
-        return {"choices": [{"delta": completion_obj}]}
+        try:
+            completion_obj = {"role": "assistant", "content": ""}
+            if self.model in litellm.anthropic_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_anthropic_chunk(chunk)
+            elif self.model == "replicate":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk
+            elif (
+                self.custom_llm_provider and self.custom_llm_provider == "together_ai"
+            ) or ("togethercomputer" in self.model):
+                chunk = next(self.completion_stream)
+                text_data = self.handle_together_ai_chunk(chunk)
+                if text_data == "":
+                    return self.__next__()
+                completion_obj["content"] = text_data
+            elif self.model in litellm.cohere_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_huggingface_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_baseten_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_ai21_chunk(chunk)
+            elif self.model in litellm.open_ai_text_completion_models:
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
+            else: # openai chat/azure models
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
+            # LOGGING
+            self.logging_obj.post_call(completion_obj["content"])
+            # return this for all models
+            return {"choices": [{"delta": completion_obj}]}
+        except:
+            raise StopIteration
 ########## Reading Config File ############################
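
After the rewrite above, every provider path yields the same {"choices": [{"delta": ...}]} shape and any failure while reading the stream ends iteration instead of raising. A hedged consumption sketch with a stand-in class (FakeStreamWrapper is an assumed name, not litellm's wrapper):

class FakeStreamWrapper:
    """Illustrative stand-in for the wrapper's new contract: a uniform delta shape
    per chunk, with any failure converted into StopIteration."""

    def __init__(self, chunks):
        self.chunks = iter(chunks)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            content = next(self.chunks)
            return {"choices": [{"delta": {"role": "assistant", "content": content}}]}
        except Exception:
            # mirror the diff: any failure (including exhaustion) ends the stream
            raise StopIteration


text = ""
for part in FakeStreamWrapper(["stream", "ing ", "now ", "caches"]):
    text += part["choices"][0]["delta"]["content"]
print(text)  # "streaming now caches"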