mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Update langchain-llama-stack.py
This commit is contained in:
parent
0da0732b07
commit
63375b8f45
1 changed files with 39 additions and 70 deletions
|
@ -1,19 +1,21 @@
|
|||
import html
|
||||
import os
|
||||
import re
|
||||
import html
|
||||
import tempfile
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from readability import Document as ReadabilityDocument
|
||||
from markdownify import markdownify
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_community.document_loaders import PyPDFLoader, TextLoader
|
||||
import tempfile
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from typing import Optional, List, Any
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from markdownify import markdownify
|
||||
from readability import Document as ReadabilityDocument
|
||||
from rich.pretty import pprint
|
||||
|
||||
# Global variables
|
||||
|
@ -31,7 +33,7 @@ summary_template = PromptTemplate(
|
|||
|
||||
{document}
|
||||
|
||||
SUMMARY:"""
|
||||
SUMMARY:""",
|
||||
)
|
||||
|
||||
facts_template = PromptTemplate(
|
||||
|
@ -41,7 +43,7 @@ facts_template = PromptTemplate(
|
|||
{document}
|
||||
|
||||
KEY FACTS:
|
||||
-"""
|
||||
-""",
|
||||
)
|
||||
|
||||
qa_template = PromptTemplate(
|
||||
|
@ -53,36 +55,13 @@ DOCUMENT:
|
|||
|
||||
QUESTION: {question}
|
||||
|
||||
ANSWER:"""
|
||||
ANSWER:""",
|
||||
)
|
||||
|
||||
class LlamaStackLLM(LLM):
|
||||
"""Simple LangChain wrapper for Llama Stack"""
|
||||
|
||||
# Pydantic model fields
|
||||
client: Any = None
|
||||
model_id: str = "llama3:70b-instruct"
|
||||
|
||||
def __init__(self, client, model_id: str = "llama3:70b-instruct"):
|
||||
# Initialize with field values
|
||||
super().__init__(client=client, model_id=model_id)
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
|
||||
"""Make inference call to Llama Stack"""
|
||||
response = self.client.inference.chat_completion(
|
||||
model_id=self.model_id,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
return response.completion_message.content
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "llama_stack"
|
||||
|
||||
|
||||
def load_document(source: str) -> str:
|
||||
is_url = source.startswith(('http://', 'https://'))
|
||||
is_pdf = source.lower().endswith('.pdf')
|
||||
is_url = source.startswith(("http://", "https://"))
|
||||
is_pdf = source.lower().endswith(".pdf")
|
||||
if is_pdf:
|
||||
return load_pdf(source, is_url=is_url)
|
||||
elif is_url:
|
||||
|
@ -110,19 +89,22 @@ def load_pdf(source: str, is_url: bool = False) -> str:
|
|||
|
||||
|
||||
def load_from_url(url: str) -> str:
|
||||
headers = {'User-Agent': 'Mozilla/5.0 (compatible; DocumentLoader/1.0)'}
|
||||
headers = {"User-Agent": "Mozilla/5.0 (compatible; DocumentLoader/1.0)"}
|
||||
response = requests.get(url, headers=headers, timeout=15)
|
||||
response.raise_for_status()
|
||||
doc = ReadabilityDocument(response.text)
|
||||
html_main = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(html_main, "html.parser")
|
||||
for tag in soup(["script", "style", "noscript", "header", "footer", "nav", "aside"]):
|
||||
for tag in soup(
|
||||
["script", "style", "noscript", "header", "footer", "nav", "aside"]
|
||||
):
|
||||
tag.decompose()
|
||||
md_text = markdownify(str(soup), heading_style="ATX")
|
||||
md_text = html.unescape(md_text)
|
||||
md_text = re.sub(r"\n{3,}", "\n\n", md_text).strip()
|
||||
return md_text
|
||||
|
||||
|
||||
def process_document(source: str):
|
||||
global summary_chain, facts_chain, processed_docs
|
||||
|
||||
|
@ -134,17 +116,14 @@ def process_document(source: str):
|
|||
print("Summary generated")
|
||||
print("🔍 Extracting key facts...")
|
||||
facts = facts_chain.invoke({"document": document})["text"]
|
||||
processed_docs[source] = {
|
||||
"document": document,
|
||||
"summary": summary,
|
||||
"facts": facts
|
||||
}
|
||||
processed_docs[source] = {"document": document, "summary": summary, "facts": facts}
|
||||
print(f"\n✅ Processing complete!")
|
||||
print(f"📊 Document: {len(document):,} chars")
|
||||
print(f"📝 Summary: {summary[:100]}...")
|
||||
print(f"🔍 Facts: {facts[:1000]}...")
|
||||
return processed_docs[source]
|
||||
|
||||
|
||||
def ask_question(question: str, source: str = None):
|
||||
"""Answer questions about processed documents"""
|
||||
global qa_chain, processed_docs
|
||||
|
@ -156,10 +135,9 @@ def ask_question(question: str, source: str = None):
|
|||
else:
|
||||
# Use the most recent document
|
||||
doc_data = list(processed_docs.values())[-1]
|
||||
answer = qa_chain.invoke({
|
||||
"document": doc_data["document"],
|
||||
"question": question
|
||||
})["text"]
|
||||
answer = qa_chain.invoke({"document": doc_data["document"], "question": question})[
|
||||
"text"
|
||||
]
|
||||
return answer
|
||||
|
||||
|
||||
|
@ -176,16 +154,16 @@ def interactive_demo():
|
|||
while True:
|
||||
try:
|
||||
command = input("\n> ").strip()
|
||||
if command.lower() in ['quit', 'exit']:
|
||||
if command.lower() in ["quit", "exit"]:
|
||||
print("👋 Thanks for exploring LangChain chains!")
|
||||
break
|
||||
elif command.lower() == 'help':
|
||||
elif command.lower() == "help":
|
||||
print("\nCommands:")
|
||||
print(" load <url_or_path> - Process a document")
|
||||
print(" ask <question> - Ask about the document")
|
||||
print(" summary - Show document summary")
|
||||
print(" facts - Show extracted facts")
|
||||
elif command.startswith('load '):
|
||||
elif command.startswith("load "):
|
||||
source = command[5:].strip()
|
||||
if source:
|
||||
try:
|
||||
|
@ -194,7 +172,7 @@ def interactive_demo():
|
|||
print(f"❌ Error processing document: {e}")
|
||||
else:
|
||||
print("❓ Please provide a URL or file path")
|
||||
elif command.startswith('ask '):
|
||||
elif command.startswith("ask "):
|
||||
question = command[4:].strip()
|
||||
if question:
|
||||
try:
|
||||
|
@ -205,13 +183,13 @@ def interactive_demo():
|
|||
print(f"❌ Error: {e}")
|
||||
else:
|
||||
print("❓ Please provide a question")
|
||||
elif command.lower() == 'summary':
|
||||
elif command.lower() == "summary":
|
||||
if processed_docs:
|
||||
latest_doc = list(processed_docs.values())[-1]
|
||||
print(f"\n📝 Summary:\n{latest_doc['summary']}")
|
||||
else:
|
||||
print("❓ No documents processed yet")
|
||||
elif command.lower() == 'facts':
|
||||
elif command.lower() == "facts":
|
||||
if processed_docs:
|
||||
latest_doc = list(processed_docs.values())[-1]
|
||||
print(f"\n🔍 Key Facts:\n{latest_doc['facts']}")
|
||||
|
@ -232,14 +210,14 @@ def main():
|
|||
client = LlamaStackClient(
|
||||
base_url="http://localhost:8321/",
|
||||
)
|
||||
|
||||
# Initialize the LangChain-compatible LLM
|
||||
llm = LlamaStackLLM(client)
|
||||
os.environ["OPENAI_API_KEY"] = "dummy"
|
||||
os.environ["OPENAI_BASE_URL"] = "http://0.0.0.0:8321/v1/openai/v1"
|
||||
llm = ChatOpenAI(model="ollama/llama3:70b-instruct")
|
||||
|
||||
# Test the wrapper
|
||||
test_response = llm.invoke("Can you help me with the document processing?")
|
||||
print(f"✅ LangChain wrapper working!")
|
||||
print(f"Response: {test_response[:100]}...")
|
||||
print(f"Response: {test_response.content[:100]}...")
|
||||
|
||||
print("Available models:")
|
||||
for m in client.models.list():
|
||||
|
@ -251,19 +229,7 @@ def main():
|
|||
print(s.identifier)
|
||||
print("----")
|
||||
|
||||
# model_id = "llama3.2:3b"
|
||||
model_id = "ollama/llama3:70b-instruct"
|
||||
|
||||
response = client.inference.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a friendly assistant."},
|
||||
{"role": "user", "content": "Write a two-sentence poem about llama."},
|
||||
],
|
||||
)
|
||||
|
||||
print(response.completion_message.content)
|
||||
|
||||
# Create chains by combining our LLM with prompt templates
|
||||
summary_chain = LLMChain(llm=llm, prompt=summary_template)
|
||||
facts_chain = LLMChain(llm=llm, prompt=facts_template)
|
||||
|
@ -278,11 +244,14 @@ def main():
|
|||
print(" • Q&A: Answers questions based on document content")
|
||||
|
||||
# Test template formatting
|
||||
test_prompt = summary_template.format(document="This is a sample document about AI...")
|
||||
test_prompt = summary_template.format(
|
||||
document="This is a sample document about AI..."
|
||||
)
|
||||
print(f"\n📝 Example prompt: {len(test_prompt)} characters")
|
||||
|
||||
# Start the interactive demo
|
||||
interactive_demo()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue