Ryanfafa committed on
Commit
549d638
·
verified ·
1 Parent(s): 4733035

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +84 -17
rag_engine.py CHANGED
@@ -1,19 +1,21 @@
1
  """
2
- rag_engine.py — Multimodal RAG Engine with Multi-File Support & Conversation Memory
3
  Supports: PDF, TXT, DOCX, CSV, XLSX, Images (JPG/PNG/WEBP)
4
  Features: Up to 5 simultaneous files, per-file removal, additive indexing
5
  Memory: sliding window of last 6 exchanges
6
 
7
- KEY CHANGES (v4 — Multi-File):
8
- 1. Additive indexing — new uploads ADD to the vectorstore, not replace it.
9
- 2. Per-file chunk tracking — each file's chunk IDs are stored for clean removal.
10
- 3. remove_file(filename) — deletes a specific file's chunks from the vectorstore.
11
- 4. MAX_FILES = 5 enforced in ingest_file().
12
- 5. _generate() is multi-file aware — system prompt lists all loaded files.
13
- 6. query() scales retrieval k with number of files for cross-doc coverage.
14
- 7. Memory is NOT cleared on upload (user may be chatting about multiple docs).
15
-
16
- Keeps all v3 image fixes: OCR, color analysis, BLIP raw bytes, VLM descriptions.
 
 
17
  """
18
 
19
  import os
@@ -44,9 +46,11 @@ logger = logging.getLogger(__name__)
44
 
45
  # ── Constants ────────────────────────────────────────────────────────────────
46
  EMBED_MODEL = "all-MiniLM-L6-v2"
 
47
  CHUNK_SIZE = 600
48
  CHUNK_OVERLAP = 100
49
- TOP_K = 4
 
50
  COLLECTION_NAME = "docmind_multimodal"
51
  HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
52
  MEMORY_WINDOW = 6 # number of past Q&A pairs to keep
@@ -119,6 +123,7 @@ def _classify_color(r: int, g: int, b: int) -> str:
119
  class RAGEngine:
120
  def __init__(self):
121
  self._embeddings: Optional[HuggingFaceEmbeddings] = None
 
122
  self._vectorstore: Optional[Chroma] = None
123
  self._splitter = RecursiveCharacterTextSplitter(
124
  chunk_size=CHUNK_SIZE,
@@ -140,6 +145,58 @@ class RAGEngine:
140
  )
141
  return self._embeddings
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # ── Memory ───────────────────────────────────────────────────────────────
144
 
145
  def clear_memory(self):
@@ -805,15 +862,25 @@ class RAGEngine:
805
  sources = []
806
 
807
  try:
808
- # Scale retrieval k with number of files for cross-doc coverage
809
- k = min(TOP_K + len(self._documents) - 1, 6)
810
- fetch_k = k * 3
 
 
 
811
 
812
  retriever = self._vectorstore.as_retriever(
813
  search_type="mmr",
814
- search_kwargs={"k": k, "fetch_k": fetch_k},
815
  )
816
- docs = retriever.invoke(question)
 
 
 
 
 
 
 
817
  context = "\n\n---\n\n".join(
818
  f"[Chunk {i+1} | source: {d.metadata.get('source', '?')} | type: {d.metadata.get('type','text')}]\n{d.page_content}"
819
  for i, d in enumerate(docs)
 
1
  """
2
+ rag_engine.py — Multimodal RAG Engine with Multi-File Support, Reranking & Memory
3
  Supports: PDF, TXT, DOCX, CSV, XLSX, Images (JPG/PNG/WEBP)
4
  Features: Up to 5 simultaneous files, per-file removal, additive indexing
5
  Memory: sliding window of last 6 exchanges
6
 
7
+ KEY CHANGES (v5 — Cross-Encoder Reranking):
8
+ 1. Cross-encoder reranker (ms-marco-MiniLM-L-6-v2) scores every retrieved
9
+ chunk for true semantic relevance to the query — not just embedding distance.
10
+ 2. Over-fetches 12+ candidates from the vectorstore, then reranks to pick
11
+ the top-k most relevant chunks for the LLM context.
12
+ 3. Graceful fallback — if the reranker fails to load, uses original order.
13
+
14
+ Previous features preserved:
15
+ - Additive indexing, per-file removal, MAX_FILES=5
16
+ - Multi-file aware generation, cross-doc coverage
17
+ - OCR, color analysis, BLIP raw bytes, VLM descriptions for images
18
+ - Conversation memory (6-exchange sliding window)
19
  """
20
 
21
  import os
 
46
 
47
  # ── Constants ────────────────────────────────────────────────────────────────
48
  EMBED_MODEL = "all-MiniLM-L6-v2"
49
+ RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" # ~80MB, CPU-friendly
50
  CHUNK_SIZE = 600
51
  CHUNK_OVERLAP = 100
52
+ TOP_K = 4 # final chunks sent to LLM after reranking
53
+ RERANK_FETCH_K = 12 # over-fetch this many candidates for reranking
54
  COLLECTION_NAME = "docmind_multimodal"
55
  HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
56
  MEMORY_WINDOW = 6 # number of past Q&A pairs to keep
 
123
  class RAGEngine:
124
  def __init__(self):
125
  self._embeddings: Optional[HuggingFaceEmbeddings] = None
126
+ self._reranker = None # lazy-loaded cross-encoder
127
  self._vectorstore: Optional[Chroma] = None
128
  self._splitter = RecursiveCharacterTextSplitter(
129
  chunk_size=CHUNK_SIZE,
 
145
  )
146
  return self._embeddings
147
 
148
+ @property
149
+ def reranker(self):
150
+ """Lazy-load the cross-encoder reranker (~80MB, CPU-friendly)."""
151
+ if self._reranker is None:
152
+ try:
153
+ from sentence_transformers import CrossEncoder
154
+ logger.info(f"Loading reranker model: {RERANK_MODEL}...")
155
+ self._reranker = CrossEncoder(RERANK_MODEL, max_length=512)
156
+ logger.info("Reranker loaded successfully.")
157
+ except Exception as e:
158
+ logger.warning(f"Failed to load reranker: {e}. Will skip reranking.")
159
+ self._reranker = False # sentinel: don't retry
160
+ return self._reranker if self._reranker is not False else None
161
+
162
+ def _rerank_documents(self, question: str, docs: List[Document], top_k: int) -> List[Document]:
163
+ """Score and reorder documents using the cross-encoder reranker."""
164
+ if not docs:
165
+ return docs
166
+
167
+ ranker = self.reranker
168
+ if ranker is None:
169
+ # Reranker unavailable — fall back to original order
170
+ logger.info("Reranker not available, using original retrieval order.")
171
+ return docs[:top_k]
172
+
173
+ # Build query-document pairs for the cross-encoder
174
+ pairs = [(question, doc.page_content) for doc in docs]
175
+
176
+ try:
177
+ scores = ranker.predict(pairs)
178
+
179
+ # Pair each doc with its rerank score
180
+ scored = list(zip(docs, scores))
181
+ scored.sort(key=lambda x: x[1], reverse=True)
182
+
183
+ reranked = [doc for doc, score in scored[:top_k]]
184
+
185
+ # Log the reranking effect
186
+ original_sources = [d.metadata.get("source", "?")[:30] for d in docs[:top_k]]
187
+ reranked_sources = [d.metadata.get("source", "?")[:30] for d in reranked]
188
+ top_scores = [f"{s:.3f}" for _, s in scored[:top_k]]
189
+ logger.info(
190
+ f"Reranked {len(docs)} candidates → top {top_k}. "
191
+ f"Scores: {top_scores}. "
192
+ f"Before: {original_sources}, After: {reranked_sources}"
193
+ )
194
+
195
+ return reranked
196
+ except Exception as e:
197
+ logger.warning(f"Reranking failed: {e}. Using original order.")
198
+ return docs[:top_k]
199
+
200
  # ── Memory ───────────────────────────────────────────────────────────────
201
 
202
  def clear_memory(self):
 
862
  sources = []
863
 
864
  try:
865
+ # ── Step 1: Over-fetch candidates ────────────────────────────────
866
+ # Retrieve more candidates than needed so the reranker can pick
867
+ # the truly relevant ones. Scale with number of loaded files.
868
+ num_files = len(self._documents)
869
+ fetch_k = max(RERANK_FETCH_K, RERANK_FETCH_K + (num_files - 1) * 2)
870
+ initial_k = fetch_k # MMR will return this many diverse candidates
871
 
872
  retriever = self._vectorstore.as_retriever(
873
  search_type="mmr",
874
+ search_kwargs={"k": initial_k, "fetch_k": fetch_k * 2},
875
  )
876
+ candidate_docs = retriever.invoke(question)
877
+
878
+ # ── Step 2: Rerank with cross-encoder ────────────────────────────
879
+ # The cross-encoder scores each (query, chunk) pair for true
880
+ # semantic relevance — much more accurate than embedding distance.
881
+ final_k = min(TOP_K + num_files - 1, 6)
882
+ docs = self._rerank_documents(question, candidate_docs, top_k=final_k)
883
+
884
  context = "\n\n---\n\n".join(
885
  f"[Chunk {i+1} | source: {d.metadata.get('source', '?')} | type: {d.metadata.get('type','text')}]\n{d.page_content}"
886
  for i, d in enumerate(docs)