Rifqi Hafizuddin commited on
Commit
430c361
·
1 Parent(s): 49feaa9

[KM-564] Edit column description in catalog to reduce token & ingestion time

Browse files
src/catalog/introspect/database.py CHANGED
@@ -130,7 +130,6 @@ class DatabaseIntrospector(BaseIntrospector):
130
  source_id=client_id,
131
  source_type="schema",
132
  name=client.name,
133
- description="",
134
  location_ref=location_ref,
135
  updated_at=datetime.now(UTC),
136
  tables=tables,
@@ -160,6 +159,7 @@ class DatabaseIntrospector(BaseIntrospector):
160
  col["name"],
161
  col.get("is_numeric", False),
162
  row_count,
 
163
  )
164
  except Exception as e:
165
  logger.error(
@@ -177,7 +177,6 @@ class DatabaseIntrospector(BaseIntrospector):
177
  Table(
178
  table_id=_stable_id("t_", table_name),
179
  name=table_name,
180
- description="",
181
  row_count=row_count,
182
  columns=columns,
183
  foreign_keys=foreign_keys,
@@ -218,18 +217,25 @@ class DatabaseIntrospector(BaseIntrospector):
218
  _normalize(v) for v in (profile.get("sample_values") or [])
219
  ] or None
220
 
 
 
 
 
 
221
  column = Column(
222
  column_id=_stable_id("c_", table_name, name),
223
  name=name,
224
  data_type=_map_sql_type(str(col["type"])),
225
- description="",
226
  nullable=True, # nullable not surfaced by extractor; default permissive
227
  pii_flag=False,
228
  sample_values=sample_values,
229
  stats=ColumnStats(
230
  min=_normalize(profile.get("min")),
231
  max=_normalize(profile.get("max")),
 
 
232
  distinct_count=profile.get("distinct_count"),
 
233
  ),
234
  )
235
  if self._pii.detect(column):
 
130
  source_id=client_id,
131
  source_type="schema",
132
  name=client.name,
 
133
  location_ref=location_ref,
134
  updated_at=datetime.now(UTC),
135
  tables=tables,
 
159
  col["name"],
160
  col.get("is_numeric", False),
161
  row_count,
162
+ is_temporal=col.get("is_temporal", False),
163
  )
164
  except Exception as e:
165
  logger.error(
 
177
  Table(
178
  table_id=_stable_id("t_", table_name),
179
  name=table_name,
 
180
  row_count=row_count,
181
  columns=columns,
182
  foreign_keys=foreign_keys,
 
217
  _normalize(v) for v in (profile.get("sample_values") or [])
218
  ] or None
219
 
220
+ top_raw = profile.get("top_values") or []
221
+ top_values: list[Any] | None = [
222
+ _normalize(v) for v, _cnt in top_raw
223
+ ] or None
224
+
225
  column = Column(
226
  column_id=_stable_id("c_", table_name, name),
227
  name=name,
228
  data_type=_map_sql_type(str(col["type"])),
 
229
  nullable=True, # nullable not surfaced by extractor; default permissive
230
  pii_flag=False,
231
  sample_values=sample_values,
232
  stats=ColumnStats(
233
  min=_normalize(profile.get("min")),
234
  max=_normalize(profile.get("max")),
235
+ mean=_normalize(profile.get("mean")),
236
+ median=_normalize(profile.get("median")),
237
  distinct_count=profile.get("distinct_count"),
238
+ top_values=top_values,
239
  ),
240
  )
241
  if self._pii.detect(column):
src/catalog/introspect/tabular.py CHANGED
@@ -141,7 +141,6 @@ class TabularIntrospector(BaseIntrospector):
141
  source_id=document_id,
142
  source_type="tabular",
143
  name=doc.filename,
144
- description="",
145
  location_ref=location_ref,
146
  updated_at=datetime.now(UTC),
147
  tables=tables,
@@ -183,7 +182,6 @@ class TabularIntrospector(BaseIntrospector):
183
  return Table(
184
  table_id=_stable_id("t_", *id_parts),
185
  name=table_name,
186
- description="",
187
  row_count=len(df),
188
  columns=columns,
189
  foreign_keys=[],
@@ -200,7 +198,7 @@ class TabularIntrospector(BaseIntrospector):
200
  (document_id, sheet_name, col_name) if sheet_name else (document_id, col_name)
201
  )
202
 
203
- sample_raw = series.dropna().head(5).tolist()
204
  sample_values: list[Any] | None = [_normalize(v) for v in sample_raw] or None
205
 
206
  is_numeric = pd.api.types.is_numeric_dtype(series)
@@ -212,9 +210,14 @@ class TabularIntrospector(BaseIntrospector):
212
  if distinct_count <= 10
213
  else None
214
  )
 
 
 
215
  stats = ColumnStats(
216
- min=_normalize(non_null.min()) if (is_numeric or is_dt) and len(non_null) > 0 else None,
217
- max=_normalize(non_null.max()) if (is_numeric or is_dt) and len(non_null) > 0 else None,
 
 
218
  distinct_count=distinct_count,
219
  top_values=top_values,
220
  )
@@ -223,7 +226,6 @@ class TabularIntrospector(BaseIntrospector):
223
  column_id=_stable_id("c_", *id_parts),
224
  name=col_name,
225
  data_type=_map_pandas_type(series.dtype),
226
- description="",
227
  nullable=bool(series.isnull().any()),
228
  pii_flag=False,
229
  sample_values=sample_values,
 
141
  source_id=document_id,
142
  source_type="tabular",
143
  name=doc.filename,
 
144
  location_ref=location_ref,
145
  updated_at=datetime.now(UTC),
146
  tables=tables,
 
182
  return Table(
183
  table_id=_stable_id("t_", *id_parts),
184
  name=table_name,
 
185
  row_count=len(df),
186
  columns=columns,
187
  foreign_keys=[],
 
198
  (document_id, sheet_name, col_name) if sheet_name else (document_id, col_name)
199
  )
200
 
201
+ sample_raw = series.dropna().head(3).tolist()
202
  sample_values: list[Any] | None = [_normalize(v) for v in sample_raw] or None
203
 
204
  is_numeric = pd.api.types.is_numeric_dtype(series)
 
210
  if distinct_count <= 10
211
  else None
212
  )
213
+ has_values = len(non_null) > 0
214
+ wants_range = (is_numeric or is_dt) and has_values
215
+ wants_mean = is_numeric and has_values
216
  stats = ColumnStats(
217
+ min=_normalize(non_null.min()) if wants_range else None,
218
+ max=_normalize(non_null.max()) if wants_range else None,
219
+ mean=float(non_null.mean()) if wants_mean else None,
220
+ median=float(non_null.median()) if wants_mean else None,
221
  distinct_count=distinct_count,
222
  top_values=top_values,
223
  )
 
226
  column_id=_stable_id("c_", *id_parts),
227
  name=col_name,
228
  data_type=_map_pandas_type(series.dtype),
 
229
  nullable=bool(series.isnull().any()),
230
  pii_flag=False,
231
  sample_values=sample_values,
src/catalog/models.py CHANGED
@@ -34,6 +34,8 @@ DataType = Literal["int", "decimal", "string", "datetime", "date", "bool", "json
34
  class ColumnStats(BaseModel):
35
  min: Any | None = None
36
  max: Any | None = None
 
 
37
  distinct_count: int | None = None
38
  top_values: list[Any] | None = None
39
 
@@ -42,7 +44,6 @@ class Column(BaseModel):
42
  column_id: str
43
  name: str
44
  data_type: DataType
45
- description: str
46
  nullable: bool
47
  pii_flag: bool = False
48
  sample_values: list[Any] | None = None
@@ -64,7 +65,6 @@ class ForeignKey(BaseModel):
64
  class Table(BaseModel):
65
  table_id: str
66
  name: str
67
- description: str
68
  row_count: int | None = None
69
  columns: list[Column]
70
  foreign_keys: list[ForeignKey] = Field(default_factory=list)
@@ -74,7 +74,6 @@ class Source(BaseModel):
74
  source_id: str
75
  source_type: SourceType
76
  name: str
77
- description: str
78
  location_ref: str
79
  updated_at: datetime
80
  tables: list[Table] = Field(default_factory=list)
 
34
  class ColumnStats(BaseModel):
35
  min: Any | None = None
36
  max: Any | None = None
37
+ mean: float | None = None
38
+ median: float | None = None
39
  distinct_count: int | None = None
40
  top_values: list[Any] | None = None
41
 
 
44
  column_id: str
45
  name: str
46
  data_type: DataType
 
47
  nullable: bool
48
  pii_flag: bool = False
49
  sample_values: list[Any] | None = None
 
65
  class Table(BaseModel):
66
  table_id: str
67
  name: str
 
68
  row_count: int | None = None
69
  columns: list[Column]
70
  foreign_keys: list[ForeignKey] = Field(default_factory=list)
 
74
  source_id: str
75
  source_type: SourceType
76
  name: str
 
77
  location_ref: str
78
  updated_at: datetime
79
  tables: list[Table] = Field(default_factory=list)
src/catalog/render.py CHANGED
@@ -8,13 +8,18 @@ from .models import Source
8
  def render_source(source: Source) -> str:
9
  """Render a Source as the canonical text block consumed by the planner.
10
 
11
- Includes stable IDs (so the LLM can echo them back), per-column data
12
- type, sample values (or `PII (suppressed)` for flagged columns), basic
13
- stats, and resolved-by-name foreign keys.
 
 
 
 
 
 
14
  """
15
  lines: list[str] = [
16
  f"Source: {source.name} ({source.source_type})",
17
- f"Source ID: {source.source_id}",
18
  "",
19
  "Tables:",
20
  ]
@@ -26,9 +31,9 @@ def render_source(source: Source) -> str:
26
 
27
  for table in source.tables:
28
  rc = table.row_count
29
- rc_str = f"({rc:,} rows) " if rc is not None else ""
30
  lines.append("")
31
- lines.append(f" Table: {table.name} {rc_str}— id={table.table_id}")
32
  lines.append(" Columns:")
33
  for col in table.columns:
34
  samples = "PII (suppressed)" if col.pii_flag else (col.sample_values or [])
@@ -38,12 +43,17 @@ def render_source(source: Source) -> str:
38
  stats_parts.append(f"min={col.stats.min}")
39
  if col.stats.max is not None:
40
  stats_parts.append(f"max={col.stats.max}")
 
 
 
 
41
  if col.stats.distinct_count is not None:
42
  stats_parts.append(f"distinct={col.stats.distinct_count}")
 
 
43
  stats_str = (", " + ", ".join(stats_parts)) if stats_parts else ""
44
  lines.append(
45
- f" - {col.name} [{col.data_type}]: samples={samples}{stats_str} "
46
- f"— id={col.column_id}"
47
  )
48
  if table.foreign_keys:
49
  lines.append(" Foreign keys:")
 
8
  def render_source(source: Source) -> str:
9
  """Render a Source as the canonical text block consumed by the planner.
10
 
11
+ Identifiers (source_id / table_id / column_id) are intentionally NOT
12
+ rendered the LLM references things by name, and the IR resolver maps
13
+ names back to stable IDs before validation. This saves ~10% input tokens
14
+ per planner call.
15
+
16
+ Columns show data type, sample values (or `PII (suppressed)`), and
17
+ populated stats only (min/max suppressed for string/bool, where they're
18
+ useless). Top values are listed when available for low-cardinality cols.
19
+ Foreign keys are resolved to names.
20
  """
21
  lines: list[str] = [
22
  f"Source: {source.name} ({source.source_type})",
 
23
  "",
24
  "Tables:",
25
  ]
 
31
 
32
  for table in source.tables:
33
  rc = table.row_count
34
+ rc_str = f" ({rc:,} rows)" if rc is not None else ""
35
  lines.append("")
36
+ lines.append(f" Table: {table.name}{rc_str}")
37
  lines.append(" Columns:")
38
  for col in table.columns:
39
  samples = "PII (suppressed)" if col.pii_flag else (col.sample_values or [])
 
43
  stats_parts.append(f"min={col.stats.min}")
44
  if col.stats.max is not None:
45
  stats_parts.append(f"max={col.stats.max}")
46
+ if col.stats.mean is not None:
47
+ stats_parts.append(f"mean={col.stats.mean:.4g}")
48
+ if col.stats.median is not None:
49
+ stats_parts.append(f"median={col.stats.median:.4g}")
50
  if col.stats.distinct_count is not None:
51
  stats_parts.append(f"distinct={col.stats.distinct_count}")
52
+ if col.stats.top_values:
53
+ stats_parts.append(f"top={col.stats.top_values}")
54
  stats_str = (", " + ", ".join(stats_parts)) if stats_parts else ""
55
  lines.append(
56
+ f" - {col.name} [{col.data_type}]: samples={samples}{stats_str}"
 
57
  )
58
  if table.foreign_keys:
59
  lines.append(" Foreign keys:")
src/pipeline/db_pipeline/extractor.py CHANGED
@@ -9,7 +9,7 @@ not user input.
9
  from typing import Optional
10
 
11
  import pandas as pd
12
- from sqlalchemy import Float, Integer, Numeric, inspect
13
  from sqlalchemy.engine import Engine
14
 
15
  from src.middlewares.logging import get_logger
@@ -17,10 +17,16 @@ from src.middlewares.logging import get_logger
17
  logger = get_logger("db_extractor")
18
 
19
  TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5%
 
 
 
 
 
 
20
 
21
  # Dialects where PERCENTILE_CONT(...) WITHIN GROUP is supported as an aggregate.
22
  # MySQL has no percentile aggregate; BigQuery has PERCENTILE_CONT only as an
23
- # analytic (window) function — both drop median and keep min/max/mean.
24
  _MEDIAN_DIALECTS = frozenset({"postgresql", "mssql", "snowflake"})
25
 
26
 
@@ -53,7 +59,7 @@ def _qi(engine: Engine, name: str) -> str:
53
  def get_schema(
54
  engine: Engine, exclude_tables: Optional[frozenset[str]] = None
55
  ) -> dict[str, list[dict]]:
56
- """Returns {table_name: [{name, type, is_numeric, is_primary_key, foreign_key}, ...]}."""
57
  exclude = exclude_tables or frozenset()
58
  inspector = inspect(engine)
59
  schema = {}
@@ -75,6 +81,7 @@ def get_schema(
75
  "name": c["name"],
76
  "type": str(c["type"]),
77
  "is_numeric": isinstance(c["type"], (Integer, Numeric, Float)),
 
78
  "is_primary_key": c["name"] in pk_cols,
79
  "foreign_key": fk_map.get(c["name"]),
80
  }
@@ -96,8 +103,14 @@ def profile_column(
96
  col_name: str,
97
  is_numeric: bool,
98
  row_count: int,
 
99
  ) -> dict:
100
- """Returns null_count, distinct_count, min/max, top values, and sample values."""
 
 
 
 
 
101
  if row_count == 0:
102
  return {
103
  "null_count": 0,
@@ -108,39 +121,69 @@ def profile_column(
108
 
109
  qt = _qi(engine, table_name)
110
  qc = _qi(engine, col_name)
 
 
 
 
 
111
 
112
- # Combined stats query: null_count, distinct_count, and min/max (if numeric).
113
- # One round-trip instead of two.
114
- select_cols = [
115
  f"COUNT(*) - COUNT({qc}) AS nulls",
116
  f"COUNT(DISTINCT {qc}) AS distincts",
117
  ]
118
- if is_numeric:
119
- select_cols.append(f"MIN({qc}) AS min_val")
120
- select_cols.append(f"MAX({qc}) AS max_val")
121
- select_cols.append(f"AVG({qc}) AS mean_val")
122
- if _supports_median(engine):
123
- select_cols.append(
124
- f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val"
125
- )
126
- stats = pd.read_sql(f"SELECT {', '.join(select_cols)} FROM {qt}", engine)
127
-
128
- null_count = int(stats.iloc[0]["nulls"])
129
- distinct_count = int(stats.iloc[0]["distincts"])
130
- distinct_ratio = distinct_count / row_count if row_count > 0 else 0
131
 
132
- profile = {
133
- "null_count": null_count,
134
- "distinct_count": distinct_count,
135
- "distinct_ratio": round(distinct_ratio, 4),
136
- }
137
-
138
- if is_numeric:
139
- profile["min"] = stats.iloc[0]["min_val"]
140
- profile["max"] = stats.iloc[0]["max_val"]
141
- profile["mean"] = stats.iloc[0]["mean_val"]
142
- if _supports_median(engine):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  profile["median"] = stats.iloc[0]["median_val"]
 
 
 
 
 
 
 
 
144
 
145
  if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
146
  top_sql = _head_query(
@@ -153,9 +196,6 @@ def profile_column(
153
  top = pd.read_sql(top_sql, engine)
154
  profile["top_values"] = list(zip(top.iloc[:, 0].tolist(), top["cnt"].tolist()))
155
 
156
- sample = pd.read_sql(_head_query(engine, qc, qt, 5), engine)
157
- profile["sample_values"] = sample.iloc[:, 0].tolist()
158
-
159
  return profile
160
 
161
 
@@ -273,7 +313,8 @@ def build_text(table_name: str, row_count: int, col: dict, profile: dict) -> str
273
  text += f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})\n"
274
  if "min" in profile:
275
  text += f"Min: {profile['min']}, Max: {profile['max']}\n"
276
- text += f"Mean: {profile['mean']}\n"
 
277
  if profile.get("median") is not None:
278
  text += f"Median: {profile['median']}\n"
279
  if "top_values" in profile:
 
9
  from typing import Optional
10
 
11
  import pandas as pd
12
+ from sqlalchemy import Date, DateTime, Float, Integer, Numeric, inspect
13
  from sqlalchemy.engine import Engine
14
 
15
  from src.middlewares.logging import get_logger
 
17
  logger = get_logger("db_extractor")
18
 
19
  TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5%
20
+ SAMPLE_LIMIT = 3 # sample N rows per column (down from 5 — token cost)
21
+
22
+ # Dialects with a single-statement CTE that survives `pd.read_sql`. On these we
23
+ # fold the stats and sample queries into one round-trip per column. MySQL <8 and
24
+ # old SQLite are excluded out of caution.
25
+ _CTE_DIALECTS = frozenset({"postgresql", "mssql", "snowflake", "bigquery"})
26
 
27
  # Dialects where PERCENTILE_CONT(...) WITHIN GROUP is supported as an aggregate.
28
  # MySQL has no percentile aggregate; BigQuery has PERCENTILE_CONT only as an
29
+ # analytic (window) function — both drop median and keep mean.
30
  _MEDIAN_DIALECTS = frozenset({"postgresql", "mssql", "snowflake"})
31
 
32
 
 
59
  def get_schema(
60
  engine: Engine, exclude_tables: Optional[frozenset[str]] = None
61
  ) -> dict[str, list[dict]]:
62
+ """Returns {table_name: [{name, type, is_numeric, is_temporal, is_primary_key, foreign_key}, ...]}."""
63
  exclude = exclude_tables or frozenset()
64
  inspector = inspect(engine)
65
  schema = {}
 
81
  "name": c["name"],
82
  "type": str(c["type"]),
83
  "is_numeric": isinstance(c["type"], (Integer, Numeric, Float)),
84
+ "is_temporal": isinstance(c["type"], (Date, DateTime)),
85
  "is_primary_key": c["name"] in pk_cols,
86
  "foreign_key": fk_map.get(c["name"]),
87
  }
 
103
  col_name: str,
104
  is_numeric: bool,
105
  row_count: int,
106
+ is_temporal: bool = False,
107
  ) -> dict:
108
+ """Returns null_count, distinct_count, min/max (numeric+temporal), mean/median (numeric), and sample values.
109
+
110
+ Numeric columns compute mean and (where the dialect supports it) median.
111
+ Datetime/date get min/max only (no useful mean/median over timestamps).
112
+ Strings/bools skip range stats entirely.
113
+ """
114
  if row_count == 0:
115
  return {
116
  "null_count": 0,
 
121
 
122
  qt = _qi(engine, table_name)
123
  qc = _qi(engine, col_name)
124
+ wants_range = is_numeric or is_temporal
125
+ wants_mean = is_numeric
126
+ wants_median = is_numeric and _supports_median(engine)
127
+
128
+ profile: dict = {}
129
 
130
+ # Build the stats SELECT list incrementally same column set used in both
131
+ # the CTE and fallback branches.
132
+ stat_cols = [
133
  f"COUNT(*) - COUNT({qc}) AS nulls",
134
  f"COUNT(DISTINCT {qc}) AS distincts",
135
  ]
136
+ if wants_range:
137
+ stat_cols += [f"MIN({qc}) AS min_val", f"MAX({qc}) AS max_val"]
138
+ if wants_mean:
139
+ stat_cols.append(f"AVG({qc}) AS mean_val")
140
+ if wants_median:
141
+ stat_cols.append(
142
+ f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val"
143
+ )
 
 
 
 
 
144
 
145
+ if engine.dialect.name in _CTE_DIALECTS:
146
+ # Single round-trip: stats + sample together via CTE.
147
+ stats_select = ", ".join(stat_cols)
148
+ passthrough = ", ".join(
149
+ f"s.{c.split(' AS ')[-1]}" for c in stat_cols
150
+ )
151
+ sql = (
152
+ f"WITH stats AS (SELECT {stats_select} FROM {qt}), "
153
+ f"sample AS ({_head_query(engine, qc + ' AS sample_val', qt, SAMPLE_LIMIT)}) "
154
+ f"SELECT {passthrough}, sample.sample_val FROM stats s CROSS JOIN sample"
155
+ )
156
+ rows = pd.read_sql(sql, engine)
157
+ null_count = int(rows.iloc[0]["nulls"])
158
+ distinct_count = int(rows.iloc[0]["distincts"])
159
+ sample_values = rows["sample_val"].tolist()
160
+ if wants_range:
161
+ profile["min"] = rows.iloc[0]["min_val"]
162
+ profile["max"] = rows.iloc[0]["max_val"]
163
+ if wants_mean:
164
+ profile["mean"] = rows.iloc[0]["mean_val"]
165
+ if wants_median:
166
+ profile["median"] = rows.iloc[0]["median_val"]
167
+ else:
168
+ # Two-query fallback (MySQL/SQLite).
169
+ stats = pd.read_sql(f"SELECT {', '.join(stat_cols)} FROM {qt}", engine)
170
+ null_count = int(stats.iloc[0]["nulls"])
171
+ distinct_count = int(stats.iloc[0]["distincts"])
172
+ if wants_range:
173
+ profile["min"] = stats.iloc[0]["min_val"]
174
+ profile["max"] = stats.iloc[0]["max_val"]
175
+ if wants_mean:
176
+ profile["mean"] = stats.iloc[0]["mean_val"]
177
+ if wants_median:
178
  profile["median"] = stats.iloc[0]["median_val"]
179
+ sample = pd.read_sql(_head_query(engine, qc, qt, SAMPLE_LIMIT), engine)
180
+ sample_values = sample.iloc[:, 0].tolist()
181
+
182
+ distinct_ratio = distinct_count / row_count if row_count > 0 else 0
183
+ profile["null_count"] = null_count
184
+ profile["distinct_count"] = distinct_count
185
+ profile["distinct_ratio"] = round(distinct_ratio, 4)
186
+ profile["sample_values"] = sample_values
187
 
188
  if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
189
  top_sql = _head_query(
 
196
  top = pd.read_sql(top_sql, engine)
197
  profile["top_values"] = list(zip(top.iloc[:, 0].tolist(), top["cnt"].tolist()))
198
 
 
 
 
199
  return profile
200
 
201
 
 
313
  text += f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})\n"
314
  if "min" in profile:
315
  text += f"Min: {profile['min']}, Max: {profile['max']}\n"
316
+ if profile.get("mean") is not None:
317
+ text += f"Mean: {profile['mean']}\n"
318
  if profile.get("median") is not None:
319
  text += f"Median: {profile['median']}\n"
320
  if "top_values" in profile: