์ค์ฉ์ Paper Reading Assistant 2
Contents
์ค์ฉ์ Paper Reading Assistant 2ยถ
0. Recap
๋ ผ๋ฌธ๊ณผ ๊ด๋ จ๋ ์ง๋ฌธ์ ์์ฑํ๊ณ ๋์ ๋ต๋ณ์ ์ฑ์ ํ ์ ์๋ ์์คํ ๊ฐ๋ฐ
method
Download paper: Arxiv์์ ๋ ผ๋ฌธ ์ ๋ณด์ pdf ๋ค์ด๋ก๋
Preprocess paper: pdf์์ ๋ ผ๋ฌธ text๋ฅผ ์ถ์ถ ํ ์ ์ฒ๋ฆฌ
Generate questions: ๋ ผ๋ฌธ์ ์ฝ๊ณ ๊ด๋ จ๋ ์ง๋ฌธ ์์ฑ
Answer questions: ์ง๋ฌธ์ ๋ํ pseudo ์ ๋ต ์์ฑ
Evaluate user answer: pseudo ์ ๋ต์ ๊ธฐ์ค์ผ๋ก ์ฌ๋์ ๋ต๋ณ ์ฑ์
1. Problems
There is a lot of unimportant text (references, table captions, etc.) → PDF layout analysis
The BERTScore model used for grading is large. → Model Distillation
2. PDF Layout Analysis
2.1. Introduction
Papers generally follow a common structure: introduction, related work, method, experiments, and conclusion.
However, automatically recovering this structure from a paper distributed as a PDF is not easy.
Various document-image-understanding models have been proposed to address this problem.
Among them, I used the VIsual LAyout (VILA) model to extract the structure of a paper.
2.2. VILA: VIsual LAyout
VILA: Improving structured content extraction from scientific PDFs using visual layout groups
๋ฌธ์์ ๊ตฌ์กฐ๋ฅผ ์ธ์ํ๋ Document Layout Analysis ๋ฌธ์ ๋ ์ฃผ๋ก token classification (NLP-centric)์ด๋ object detection (vision-centric) ๋ฌธ์ ๋ก ์นํํ์ฌ ํด๊ฒฐํ๋ค.
VILA์ ๊ฒฝ์ฐ token classification ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
VILA๋ line ๋๋ block๊ณผ ๊ฐ์ visual group ๋ด์ token๋ค์ ๊ฐ์ ๋ผ๋ฒจ์ ๊ฐ์ง๋ค๋ โgroup uniformity assumptionโ์ ๊ฐ์กฐํ๋ค.
group uniformity assumption๋ฅผ ๋ฐ๋ฅด๊ธฐ ์ํด ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ์ํ๋ค.
I-VILA: group ์ฌ์ด์ speical token [BLK] ์ ๋ ฅ
H-VILA: group ๋ณ๋ก self-attention ํ group representation ์ถ์ถํ์ฌ group ๊ฐ self-attention ์งํ
์ฑ๋ฅ์ I-VILA๊ฐ ๋ฐ์ด๋์ง๋ง, ํจ์จ์ฑ์ H-VILA๊ฐ ๋ ์ข๋ค.
Illustration of the H-VILA [1]
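As a concrete example of the I-VILA input construction, the minimal sketch below inserts a [BLK] boundary token between visual groups (the group contents are made up for illustration; this is not the actual VILA preprocessing code):

```python
# Two visual groups detected on a page: a section header and a paragraph.
groups = [
    ["3", "Method"],                       # a section-header group
    ["We", "propose", "a", "new", "..."],  # a paragraph group
]

# I-VILA flattens the groups into one token sequence, inserting the
# special [BLK] token at each group boundary so the transformer can
# see where one visual group ends and the next begins.
tokens = []
for i, group in enumerate(groups):
    if i > 0:
        tokens.append("[BLK]")  # boundary indicator between groups
    tokens.extend(group)

print(tokens)
# ['3', 'Method', '[BLK]', 'We', 'propose', 'a', 'new', '...']
```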
2.3. Code
Use a VILA model fine-tuned on the DocBank dataset [2]
Use the code and models released in the official VILA repository
2.3.1. Setup
clone the repository
!git clone https://github.com/allenai/vila.git
import sys
sys.path.append('vila/src')
import libraries
import layoutparser as lp
from collections import defaultdict
from utils import download_paper
from vila.pdftools.pdf_extractor import PDFExtractor
from vila.predictors import HierarchicalPDFPredictor, LayoutIndicatorPDFPredictor
2.3.2. Modules
predict: read the PDF and run token classification
construct_token_groups: group consecutive tokens predicted as the same class
join_group_text: join the tokens of a group into a single text
decide whether to insert a space based on the tokens' bounding boxes
construct_section_groups: extract the sections (introduction, body, conclusion, etc.) and the paragraphs belonging to each section
def predict(pdf_path, pdf_extractor, vision_model, layout_model):
page_tokens, page_images = pdf_extractor.load_tokens_and_image(pdf_path)
pred_tokens = []
for page_token, page_image in zip(page_tokens, page_images):
blocks = vision_model.detect(page_image)
page_token.annotate(blocks=blocks)
pdf_data = page_token.to_pagedata().to_dict()
pred_tokens += layout_model.predict(pdf_data, page_token.page_size)
return pred_tokens
def construct_token_groups(pred_tokens):
groups, group, group_type, prev_bbox = [], [], None, None
for token in pred_tokens:
if group_type is None:
is_continued = True
elif token.type == group_type:
if group_type == 'section':
is_continued = abs(prev_bbox[3] - token.coordinates[3]) < 1.
else:
is_continued = True
else:
is_continued = False
# print(token.text, token.type, is_continued)
group_type = token.type
prev_bbox = token.coordinates
if is_continued:
group.append(token)
else:
groups.append(group)
group = [token]
if group:
groups.append(group)
return groups
def join_group_text(group):
text = ''
prev_bbox = None
for token in group:
if not text:
text += token.text
else:
if abs(prev_bbox[2] - token.coordinates[0]) > 2:
text += ' ' + token.text
else:
text += token.text
prev_bbox = token.coordinates
return text
def construct_section_groups(token_groups):
section_groups = defaultdict(list)
section = None
for group in token_groups:
group_type = group[0].type
group_text = join_group_text(group)
if group_type == 'section':
section = group_text
section_groups[section]
elif group_type == 'paragraph' and section is not None:
section_groups[section].append(group_text)
section_groups = {k: ' '.join(v) for k,v in section_groups.items()}
return section_groups
2.3.3. Run
prepare models
pdf_extractor = PDFExtractor("pdfplumber")
vision_model = lp.EfficientDetLayoutModel("lp://PubLayNet")
layout_model = HierarchicalPDFPredictor.from_pretrained("allenai/hvila-row-layoutlm-finetuned-docbank")
inference
pdf_path = '2307.03170v1.pdf'
pred_tokens = predict(pdf_path, pdf_extractor, vision_model, layout_model)
token_groups = construct_token_groups(pred_tokens)
section_groups = construct_section_groups(token_groups)
2.3.4. Results
section ๋ชฉ๋ก
sections = list(section_groups.keys())
print(sections)
section text
print(section_groups['6 Limitations and future work'])
3. BERTScore Distillation
3.1. Introduction
BERTScore measures the similarity of two sentences using a pretrained language model. It is mainly used to evaluate text generation models for tasks such as translation and summarization [3].
The larger the language model, the higher BERTScore's correlation with human evaluation tends to be.
However, large models are hard to run in real time in an application.
To address this, I used knowledge distillation to train a small model to mimic the BERTScore of a large model.
Resulting model: yongsun-yoon/minilmv2-bertscore-distilled
Correlation between each model's BERTScore and human evaluation [4]
3.2. Setup
Student model: the lightweight nreimers/MiniLMv2-L6-H384-distilled-from-RoBERTa-Large
Teacher model: microsoft/deberta-large-mnli, ranked third in correlation with human evaluation
import math
import wandb
import easydict
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn.functional as F
import huggingface_hub
from bert_score import BERTScorer
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
cfg = easydict.EasyDict(
device = 'cuda:0',
student_name = 'nreimers/MiniLMv2-L6-H384-distilled-from-RoBERTa-Large',
teacher_name = 'microsoft/deberta-large-mnli',
teacher_layer_idx = 18,
lr = 5e-5,
batch_size = 8,
num_epochs = 5
)
3.3. Data
Chose the GLUE MNLI dataset to train on sentence pairs that have some degree of similarity
class Dataset(torch.utils.data.Dataset):
def __init__(self, data, text1_key, text2_key):
self.data = data
self.text1_key = text1_key
self.text2_key = text2_key
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
text1 = item[self.text1_key]
text2 = item[self.text2_key]
return text1, text2
def collate_fn(self, batch):
texts1, texts2 = zip(*batch)
return texts1, texts2
def get_dataloader(self, batch_size, shuffle):
return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)
data = load_dataset('glue', 'mnli')
train_data = data['train']
train_data = train_data.train_test_split(train_size=80000)['train']
train_dataset = Dataset(train_data, 'premise', 'hypothesis')
train_loader = train_dataset.get_dataloader(cfg.batch_size, True)
test_data = data['validation_mismatched'].train_test_split(test_size=4000)['test']
test_dataset = Dataset(test_data, 'premise', 'hypothesis')
test_loader = test_dataset.get_dataloader(cfg.batch_size, False)
3.4. Modelยถ
teacher_tokenizer = AutoTokenizer.from_pretrained(cfg.teacher_name)
teacher_model = AutoModel.from_pretrained(cfg.teacher_name)
_ = teacher_model.eval().requires_grad_(False).to(cfg.device)
student_tokenizer = AutoTokenizer.from_pretrained(cfg.student_name)
student_model = AutoModel.from_pretrained(cfg.student_name)
_ = student_model.train().to(cfg.device)
optimizer = torch.optim.Adam(student_model.parameters(), lr=cfg.lr)
3.5. Train
Compute cross-attention scores between the two sentences, then train the student to mimic the teacher's attention distribution.
Use KL divergence, which measures the difference between the two distributions, as the loss function.
When the teacher and student use different tokenizers, a token-level comparison is impossible; to resolve this, token embeddings are pooled into word-level embeddings.
def get_word_embeds(model, tokenizer, texts, layer_idx=-1, max_length=384):
inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(model.device)
outputs = model(**inputs, output_hidden_states=True)
num_texts = inputs.input_ids.size(0)
token_embeds = outputs.hidden_states[layer_idx]
batch_word_embeds = []
for i in range(num_texts):
text_word_embeds = []
j = 0
while True:
token_span = inputs.word_to_tokens(i, j)
if token_span is None: break
word_embed = token_embeds[i][token_span.start:token_span.end].mean(dim=0)
text_word_embeds.append(word_embed)
j += 1
text_word_embeds = torch.stack(text_word_embeds, dim=0).unsqueeze(0) # (1, seq_length, hidden_dim)
batch_word_embeds.append(text_word_embeds)
return batch_word_embeds
def kl_div_loss(s, t, temperature):
if len(s.size()) != 2:
s = s.view(-1, s.size(-1))
t = t.view(-1, t.size(-1))
s = F.log_softmax(s / temperature, dim=-1)
t = F.softmax(t / temperature, dim=-1)
return F.kl_div(s, t, reduction='batchmean') * (temperature ** 2)
def transpose_for_scores(h, num_heads):
batch_size, seq_length, dim = h.size()
head_size = dim // num_heads
h = h.view(batch_size, seq_length, num_heads, head_size)
return h.permute(0, 2, 1, 3) # (batch, num_heads, seq_length, head_size)
def attention(h1, h2, num_heads, attention_mask=None):
# assert h1.size() == h2.size()
head_size = h1.size(-1) // num_heads
h1 = transpose_for_scores(h1, num_heads) # (batch, num_heads, seq_length, head_size)
h2 = transpose_for_scores(h2, num_heads) # (batch, num_heads, seq_length, head_size)
attn = torch.matmul(h1, h2.transpose(-1, -2)) # (batch_size, num_heads, seq_length, seq_length)
attn = attn / math.sqrt(head_size)
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1 - attention_mask) * -10000.0
attn = attn + attention_mask
return attn
def train_epoch(
teacher_model, teacher_tokenizer,
student_model, student_tokenizer,
train_loader,
teacher_layer_idx,
):
student_model.train()
pbar = tqdm(train_loader)
for texts1, texts2 in pbar:
teacher_embeds1 = get_word_embeds(teacher_model, teacher_tokenizer, texts1, layer_idx=teacher_layer_idx)
teacher_embeds2 = get_word_embeds(teacher_model, teacher_tokenizer, texts2, layer_idx=teacher_layer_idx)
student_embeds1 = get_word_embeds(student_model, student_tokenizer, texts1, layer_idx=-1)
student_embeds2 = get_word_embeds(student_model, student_tokenizer, texts2, layer_idx=-1)
teacher_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
student_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
loss1 = torch.stack([kl_div_loss(ts, ss, temperature=1.) for ts, ss in zip(teacher_scores1, student_scores1)]).mean()
teacher_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
student_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
loss2 = torch.stack([kl_div_loss(ts, ss, temperature=1.) for ts, ss in zip(teacher_scores2, student_scores2)]).mean()
loss = (loss1 + loss2) * 0.5
optimizer.zero_grad()
loss.backward()
optimizer.step()
log = {'loss': loss.item(), 'loss1': loss1.item(), 'loss2': loss2.item()}
wandb.log(log)
pbar.set_postfix(log)
def test_epoch(
teacher_model, teacher_tokenizer,
student_model, student_tokenizer,
test_loader,
teacher_layer_idx,
):
student_model.eval()
test_loss, num_data = 0, 0
for texts1, texts2 in test_loader:
with torch.no_grad():
teacher_embeds1 = get_word_embeds(teacher_model, teacher_tokenizer, texts1, layer_idx=teacher_layer_idx)
teacher_embeds2 = get_word_embeds(teacher_model, teacher_tokenizer, texts2, layer_idx=teacher_layer_idx)
student_embeds1 = get_word_embeds(student_model, student_tokenizer, texts1, layer_idx=-1)
student_embeds2 = get_word_embeds(student_model, student_tokenizer, texts2, layer_idx=-1)
teacher_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
student_scores1 = [attention(e1, e2, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
loss1 = torch.stack([kl_div_loss(ts, ss, temperature=1.) for ts, ss in zip(teacher_scores1, student_scores1)]).mean()
teacher_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(teacher_embeds1, teacher_embeds2)]
student_scores2 = [attention(e2, e1, 1) for e1, e2 in zip(student_embeds1, student_embeds2)]
loss2 = torch.stack([kl_div_loss(ts, ss, temperature=1.) for ts, ss in zip(teacher_scores2, student_scores2)]).mean()
loss = (loss1 + loss2) * 0.5
batch_size = len(texts1)
test_loss += loss.item() * batch_size
num_data += batch_size
test_loss /= num_data
return test_loss
wandb.init(project='bert-score-distillation')
best_loss = 1e10
for ep in range(cfg.num_epochs):
train_epoch(teacher_model, teacher_tokenizer, student_model, student_tokenizer, train_loader, cfg.teacher_layer_idx)
test_loss = test_epoch(teacher_model, teacher_tokenizer, student_model, student_tokenizer, test_loader, cfg.teacher_layer_idx)
print(f'ep {ep:02d} | loss {test_loss:.3f}')
if test_loss < best_loss:
student_model.save_pretrained('checkpoint')
student_tokenizer.save_pretrained('checkpoint')
best_loss = test_loss
wandb.log({'test_loss': test_loss})
3.6. Evaluate
After training, the correlation with the teacher's BERTScore improved from 0.806 to 0.936.
def calculate_score(scorer, loader):
scores = []
for texts1, texts2 in tqdm(loader):
P, R, F = scorer.score(texts1, texts2)
scores += F.tolist()
return scores
teacher_scorer = BERTScorer(model_type=cfg.teacher_name, num_layers=cfg.teacher_layer_idx)
student_scorer = BERTScorer(model_type=cfg.student_name, num_layers=6)
distilled_student_scorer = BERTScorer(model_type='checkpoint', num_layers=6)
teacher_scores = calculate_score(teacher_scorer, test_loader)
student_scores = calculate_score(student_scorer, test_loader)
distilled_scores = calculate_score(distilled_student_scorer, test_loader)
scores = pd.DataFrame({'teacher': teacher_scores, 'student': student_scores, 'distilled': distilled_scores})
scores.corr().round(3)
The scatterplot also shows that after distillation the student follows the teacher's BERTScore much more closely.
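The scatterplot can be reproduced with seaborn. The sketch below uses synthetic scores as a stand-in for the `scores` DataFrame built above, purely to illustrate the plotting code:

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Synthetic stand-in for the real `scores` DataFrame: the distilled
# student tracks the teacher more tightly than the original student.
rng = np.random.default_rng(0)
teacher = rng.uniform(0.5, 1.0, 200)
scores = pd.DataFrame({
    'teacher': teacher,
    'student': teacher + rng.normal(0, 0.08, 200),
    'distilled': teacher + rng.normal(0, 0.02, 200),
})

# One panel per student model, teacher score on the x-axis.
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)
sns.scatterplot(data=scores, x='teacher', y='student', ax=axes[0], s=12)
axes[0].set_title('student (before distillation)')
sns.scatterplot(data=scores, x='teacher', y='distilled', ax=axes[1], s=12)
axes[1].set_title('student (after distillation)')
fig.tight_layout()
```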
Reference
[1] Shen, Z., Lo, K., Wang, L. L., Kuehl, B., Weld, D. S., & Downey, D. (2022). VILA: Improving structured content extraction from scientific PDFs using visual layout groups. Transactions of the Association for Computational Linguistics, 10, 376-392.
[2] Li, M., Xu, Y., Cui, L., Huang, S., Wei, F., Li, Z., & Zhou, M. (2020). DocBank: A benchmark dataset for document layout analysis.ย arXiv preprint arXiv:2006.01038.
[3] Zhang, T., Kishore, V., Wu, F., Weinberger, K. Q., & Artzi, Y. (2019). Bertscore: Evaluating text generation with bert.ย arXiv preprint arXiv:1904.09675.