mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
v0 add TokenIterator, stream support
This commit is contained in:
parent
b6a6942867
commit
998094d38d
1 changed files with 49 additions and 1 deletions
|
@ -25,6 +25,33 @@ class SagemakerError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class TokenIterator:
|
||||||
|
def __init__(self, stream):
|
||||||
|
self.byte_iterator = iter(stream)
|
||||||
|
self.buffer = io.BytesIO()
|
||||||
|
self.read_pos = 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
while True:
|
||||||
|
self.buffer.seek(self.read_pos)
|
||||||
|
line = self.buffer.readline()
|
||||||
|
if line and line[-1] == ord("\n"):
|
||||||
|
self.read_pos += len(line) + 1
|
||||||
|
full_line = line[:-1].decode("utf-8")
|
||||||
|
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
|
||||||
|
return line_data["token"]["text"]
|
||||||
|
chunk = next(self.byte_iterator)
|
||||||
|
self.buffer.seek(0, io.SEEK_END)
|
||||||
|
self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
||||||
|
|
||||||
|
|
||||||
class SagemakerConfig:
|
class SagemakerConfig:
|
||||||
"""
|
"""
|
||||||
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
|
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
|
||||||
|
@ -121,7 +148,6 @@ def completion(
|
||||||
|
|
||||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||||
inference_params = deepcopy(optional_params)
|
inference_params = deepcopy(optional_params)
|
||||||
inference_params.pop("stream", None)
|
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
config = litellm.SagemakerConfig.get_config()
|
config = litellm.SagemakerConfig.get_config()
|
||||||
|
@ -152,6 +178,28 @@ def completion(
|
||||||
hf_model_name or model
|
hf_model_name or model
|
||||||
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
||||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||||
|
stream = inference_params.pop("stream", None)
|
||||||
|
if stream == True:
|
||||||
|
data = json.dumps(
|
||||||
|
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
||||||
|
).encode("utf-8")
|
||||||
|
## LOGGING
|
||||||
|
request_str = f"""
|
||||||
|
response = client.invoke_endpoint_with_response_stream(
|
||||||
|
EndpointName={model},
|
||||||
|
ContentType="application/json",
|
||||||
|
Body={data},
|
||||||
|
CustomAttributes="accept_eula=true",
|
||||||
|
)
|
||||||
|
""" # type: ignore
|
||||||
|
response = client.invoke_endpoint_with_response_stream(
|
||||||
|
EndpointName=model,
|
||||||
|
ContentType="application/json",
|
||||||
|
Body=data,
|
||||||
|
CustomAttributes="accept_eula=true",
|
||||||
|
)
|
||||||
|
|
||||||
|
return response["Body"]
|
||||||
|
|
||||||
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
||||||
"utf-8"
|
"utf-8"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue