from web_apps.rag.extractor.extractor_base import BaseExtractor
from langchain_core.documents import Document
import logging
from typing import Optional
import pandas as pd
from openpyxl.reader.excel import load_workbook
from pandas import DataFrame
from web_apps.rag.extractor.table_utils import split_to_multi_document_tables, extract_tables_by_chunk_max_token
# Single characters that do not count as real cell content: a cell whose
# stripped text is exactly one of these characters is treated as blank when
# the used area of a sheet is detected (see _find_sheet_real_max_rows_columns).
_punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~!“”‘’(),。:;、?》《」「【】"""
# Metadata key under which the originating sheet name is attached to each
# document produced by ExcelExtractor.extract.
sub_theme_name = 'sub_theme_name'
def _convert_to_html(table: list) -> str:
html = "<table>"
for row in table:
html += "<tr>"
for cell in row:
html += f"<td>{cell}</td>"
html += "</tr>"
html += "</table>"
return html
def _handler_merged_cell_df(sheet, df: DataFrame, header=None) -> DataFrame:
'''
处理合并单元格。将合并单元格内的值设置为合并单元格的值(即start_cell的值)
:param sheet:
:param df:
:param header:
:return:
'''
for item in sheet.merged_cells:
top_col, top_row, bottom_col, bottom_row = item.bounds
base_value = item.start_cell.value
top_row -= 1
top_col -= 1
if header is not None:
top_row -= header + 1
bottom_row -= header + 1
if top_row < 0 or bottom_row < 0:
continue
df.iloc[top_row:bottom_row, top_col:bottom_col] = base_value
return df
def _find_df_real_header(df: DataFrame, start_header=0, max_header=2):
tmp_header = start_header
headers = [hd for hd in df.columns.tolist() if "Unnamed:" not in hd]
if len(headers) == len(df.columns) or start_header >= max_header:
return tmp_header
else:
tmp_header = tmp_header + 1
for row in df.values.tolist():
tmp_row = [r for r in row if not pd.isna(r)]
if len(tmp_row) < len(df.columns) and tmp_header < max_header:
tmp_header = tmp_header + 1
else:
break
return tmp_header
def _parse_df_header_and_values(df: DataFrame):
headers = df.columns.tolist()
values = df.values.tolist()
datas = []
for row_idx, row in enumerate(values):
row_data = []
for col_idx, col in enumerate(row):
col_value = str(col)
if col_value == 'nan':
col_value = ''
row_data.append(col_value.strip())
datas.append(row_data)
return headers, datas
def _find_sheet_real_max_rows_columns(sheet):
sheet_max_row = min(sheet.max_row, int("100000"))
sheet_max_column = min(sheet.max_column, int("100"))
tolerance_continue_row_blank = int("20")
tolerance_continue_column_blank = int("10")
def _invalid_value(v):
return (v is None
or str(v).strip() == ''
or str(v).strip() == 'nan'
or (len(str(v).strip()) == 1 and str(v).strip() in _punctuation))
max_row = 0
max_columns = []
continue_row_blank = 0
for i in range(1, sheet_max_row + 1):
continue_col_blank = 0
row_max_column = 0
for j in range(1, sheet_max_column + 1):
cell = sheet.cell(row=i, column=j)
if _invalid_value(cell.value):
continue_col_blank = continue_col_blank + 1
else:
row_max_column = j
continue_col_blank = 0
if continue_col_blank >= tolerance_continue_column_blank:
break
max_columns.append(row_max_column)
if row_max_column == 0:
continue_row_blank = continue_row_blank + 1
else:
max_row = i
continue_row_blank = 0
if continue_row_blank >= tolerance_continue_row_blank:
break
max_column = max(max_columns)
logging.info(f"sheet({sheet.title}):[{sheet.max_row},{sheet.max_column},real:[{max_row},{max_column}]")
return max_row, max_column
def read_xlsx(file):
    """Iterate over the sheets of an xlsx workbook as header/value tables.

    For each sheet the used area is detected, the sheet is parsed with
    pandas, the header row is re-detected (up to row index 6) and the sheet
    is re-parsed if the header moved, merged cells are filled, and the
    result is converted to plain strings.

    :param file: workbook source passed to ``openpyxl.load_workbook``
    :return: generator yielding ``(sheet_name, headers, values)`` per sheet
    """
    workbook = pd.ExcelFile(load_workbook(file), engine="openpyxl")
    max_header = 6
    for sheet_name in workbook.sheet_names:
        sheet = workbook.book[sheet_name]
        used_rows, used_columns = _find_sheet_real_max_rows_columns(sheet)
        # Both parse passes read the same detected area of the sheet.
        area = {"nrows": used_rows, "usecols": range(0, used_columns)}

        df = workbook.parse(sheet_name, header=0, **area)
        header = _find_df_real_header(df, start_header=0, max_header=max_header)
        if header > 0:
            df = workbook.parse(sheet_name, header=header, **area)

        df = _handler_merged_cell_df(sheet, df, header=header)
        headers, values = _parse_df_header_and_values(df)
        yield sheet_name, headers, values
class ExcelExtractor(BaseExtractor):
    """Extract table documents from an Excel workbook.

    Args:
        file_path: Path to the file to load.
        encoding: Stored on the instance; not read by ``extract``.
        autodetect_encoding: Stored on the instance; not read by ``extract``.
    """

    def __init__(
            self,
            file_path: str,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = False
    ):
        """Store the file path and the encoding options."""
        self._file_path = file_path
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding

    def extract(self) -> list[Document]:
        """Turn every sheet of the workbook into table documents.

        Each sheet read by ``read_xlsx`` is cut off at its first
        ``Unnamed:`` column, chunked with
        ``extract_tables_by_chunk_max_token`` (1024 tokens per table) and
        split into documents tagged with the sheet name.
        """
        documents = []
        for sheet_name, headers, values in read_xlsx(self._file_path):
            first_unnamed = next(
                (idx for idx, header in enumerate(headers)
                 if str(header).startswith('Unnamed:')),
                0,
            )
            # Index 0 doubles as "nothing found": in both cases all columns
            # are kept.
            keep = first_unnamed or len(headers)

            table = [headers[:keep]]
            table.extend(row[:keep] for row in values)

            table_htmls = extract_tables_by_chunk_max_token(
                table_array=table, table_max_token=1024) + '\n'
            documents.extend(
                split_to_multi_document_tables(
                    text=table_htmls, metadata={sub_theme_name: sheet_name}))
        return documents
if __name__ == '__main__':
    # Manual smoke run: extract 'test.xlsx' from the working directory and
    # print the resulting documents.
    print(ExcelExtractor('test.xlsx').extract())