์๊ธฐํ - LLM for Daicon
Contents
์๊ธฐํ - LLM for Daiconยถ
Causal LMยถ
what is causal langauge model?ยถ
GPT, llama, Alpaca, palm โฆetc
์ด์ ์ ์ ๋ ฅ ํ ํฐ์ ํตํด ๋ค์ ํ ํฐ์ ์์ฑํ๋ ์๊ธฐํ๊ท์ ๋ชจ๋ธ์ด๋ค. ์ด์ ์ ๋ฌธ๋งฅ์ ๋ฐ์ํ์ฌ ์ ๋ ฅ์ ๋ํ ์ถ๋ ฅ์ ๋ด๋๊ธฐ ๋๋ฌธ์ Causal LanguageModel๋ก ๋ถ๋ฆฐ๋ค.
Prompt Tuningยถ
ํ๋กฌํํธ ํ๋์ ํ์ธํ๋์ ํ ๊ฐ๋๋ก์, auto regressive ํ ๋ชจ๋ธ์ ์ธํ๋ฐ์ดํฐ์ ํน์ ํ ์ธ์คํฐ๋ญ์ ์ ์ฃผ์ด ์ํ๋ ๋ต๋ณ์ ์์ฑํ๊ฒ ํ๋ ํ์คํฌ์ด๋ค.
Alpaca Finetuning
Alpaca-7b ๋ชจ๋ธ์ ํ์ธ ํ๋ํ๋ ๊ฒ ์์ฒด๋ ์ฌ์ค ์ด๋ ค์ด ํ์คํฌ๊ฐ ์๋์๊ณ , ์ฝ๋ ํ ์ค๋ก๋ ๊ฐ๋ฅํ ์ผ์ด์๋ค.
์๋์ ํฌ๋งท์ ๋ง์ถ์ด jsonํํ๋ก ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ ํ๋์ ์ํค๋ฉด ๋์๋ค.
### ํ์ธ ํ๋ ์ ํ๋กฌํํธ ํฌ๋งท def generate_prompt(data_point): return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501 ### Instruction: {data_point["instruction"]} ### Input: {data_point["input"]} ### Response: {data_point["output"]}""" ### inference ์์ ํ๋กฌํํธ ํฌ๋งท def create_prompt(data_point): return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501 ### Instruction: {data_point["instruction"]} ### Input: {data_point["input"]} ### Response: """ ### inference ๋ต๋ณ ์์ฑ ํจ์๋ค def generate_response(prompt: str, model: model): encoding = tokenizer(prompt, return_tensors="pt") input_ids = encoding["input_ids"].to(DEVICE) generation_config = GenerationConfig( temperature=0.1, top_p=0.75, repetition_penalty=1.1, ) with torch.inference_mode(): return model.generate( input_ids=input_ids, generation_config=generation_config, return_dict_in_generate=True, output_scores=True, max_new_tokens=1, ) def format_response(response) -> str: decoded_output = tokenizer.decode(response.sequences[0]) response = decoded_output.split("### Response:")[1].strip() return "\n".join(textwrap.wrap(response)) def ask_alpaca(prompt, model: model) -> str: prompt = create_prompt(prompt) response = generate_response(prompt, model) # ๊ต์ฅํ ๋ง์ ๋ต๋ณ์ด ์์ฑ๋จ response = format_response(response) return response ###
Loss Function of Auto Regressive Model - ์ฌ๋ด
์์ฑ ๋ชจ๋ธ์ loss function์ ์ด๋ป๊ฒ ๊ตด๋ฌ๊ฐ๋์ง ๊ถ๊ธํด์ ํ๋ฒ ์ฐพ์๋ด
## huggning face์ nlp course์ ์๋ ์ฝ๋ from torch.nn import CrossEntropyLoss import torch def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0): # Shift so that tokens < n predict n shift_labels = inputs[..., 1:].contiguous() shift_logits = logits[..., :-1, :].contiguous() # Calculate per-token loss loss_fct = CrossEntropyLoss(reduce=False) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # Resize and average loss per sample loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1) # Calculate and scale weighting weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum( axis=[0, 2] ) weights = alpha * (1.0 + weights) # Calculate weighted average weighted_loss = (loss_per_sample * weights).mean() return weighted_loss
Token_1 ์ ํ๊ฒ์ Token_2 ์ด๋ค.
๋ฐ๋ผ์ ํ๊ฒ ํ ํฐ์ 2๋ฒ์งธ๋ถํฐ ๋ง์ง๋ง๊น์ง, ์ ๋ ฅ ํ ํฐ์ 1๋ฒ์งธ๋ถํฐ ๋ง์ง๋ง ์ด์ ํ ํฐ๊น์ง.
๊ฐ loss๋ฅผ ๊ณ์ฐํด์ค ์ดํ์๋, ๋ฑ์ฅ ๋น๋๋ฅผ ๊ฐ์ค์น๋ก ์ฌ์ฉํ์ฌ ๋ชจ๋ ์ํ์ ๋ํ ๊ฐ์ค ํ๊ท ์ ๊ณ์ฐํด์ค๋ค.
p-tunig for classificationยถ
hugging face ์ peft์ ์์
peft_config = PromptTuningConfig( task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT, num_virtual_tokens=8, prompt_tuning_init_text="Classify if the tweet is a complaint or not:", ## ์ธ์คํธ๋ญ์ ์ ๋ฐ๋ก ์ค ์ ์๋ค. tokenizer_name_or_path=model_name_or_path, ) ... f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"} Label : ',
Alpaca์ ์์ - ์ ์ฐธ๊ณ
Dacon์ ์ ์ฉยถ
๋ฐฐ๊ฒฝยถ
์๊ณ ์ ์น์ ์ฌ๋ถ๋ฅผ ๋ฌผ์ด๋ณด๋ ๊ณผ์ ์์ผ๋ฏ๋ก, ๋จ์ํ classification ๋ณด๋ค ๋ณต์กํ ๋ฌธ์ ๋ผ๊ณ ์๊ฐํ๋ค.
์ด ๋ฌธ์ ๋ฅผ alpaca์ prompt-tuning์ผ๋ก ์ ๊ทผํ์ฌ ํด๊ฒฐ ๊ฐ๋ฅํ ๋ฌธ์ ๋ผ๊ณ ์๊ฐํจ.
๋ฐ์ดํฐ ํ์ ๋ฐ ์ ์ฒ๋ฆฌยถ
๋ฐ์ดํฐ ๊ธธ์ด ํ์
train length

test length

๋ผ๋ฒจ ํ์

ner ํ๊ทธ ํ์
first_party

second_party

first_party, second_party์ ner์ ํ์ํ์ฌ ์ธ๋ช , ์ง๋ช , ์กฐ์ง ๋ฑ๋ฑ์ ๋ง์ถ์ด ๋ฃฐ ๋ฒ ์ด์ค๋ก ์ฒ๋ฆฌํ์ฌ facts ์์ ๋ฑ์ฅํ๋์ง ์ฌ๋ถ๋ฅผ ์ฐพ๊ณ ์ ํจ
full name์ด ์ด๋ฏธ facts์์ ์๋ ๊ฒฝ์ฐ
์ด ๊ฒฝ์ฐ์ ner ์๊ด ์์ด ๋ฐ๋ก prompt๋ฅผ ์์ฑ ๊ฐ๋ฅํ๋ฏ๋ก ๋ฐ๋ก ๋นผ์ด ๊ด๋ฆฌ
1550 ๊ฐ์ row
full_name์ด facts์์ ์๋ ๊ฒฝ์ฐ
ner์ ์ ์ฉํ ํ์, ner ํ๊ทธ๊ฐ ํ๋๋ง ์๋ ๋ฐ์ดํฐ๋ค์ ๋ชจ์์ ์ฒ๋ฆฌ
PER ํ๊ทธ : ๊ฐ๋จํ ํด๋ ์ง ์ดํ์ first_name๊ณผ last_name์ด ์๋์ง๋ฅผ ์ฒดํฌํ์๋ค.
ORG, LOC, MISC : ๊ฐ๋จํ ํด๋ ์ง ์ดํ์ n-gram์ ์ ์ฉํ์ฌ facts์์ ์๋์ง ์ฌ๋ถ๋ฅผ ์ฒดํฌํ์๋ค.
498 row
full_name์ด facts์์ ์๊ณ ner tag๊ฐ ์ฌ๋ฌ๊ฐ์ธ ๊ฒฝ์ฐ
์ด ๊ฒฝ์ฐ์๋ 2๋ฒ์ ๊ฒฝ์ฐ๋ณด๋ค ์ผ๋ฐ์ ์ผ๋ก ์ฒ๋ฆฌํด์ผ ํ์ผ๋ฏ๋ก n-gram์ ์ ์ฉํ์ฌ facts์์ ์๋์ง ์ฌ๋ถ๋ฅผ ์ฒดํฌํ์๋ค.
430 rows
ํ๋กฌํํธ ์์ฑ
๋ง์ฝ first_party๊ฐ ์์ผ๋ฉด, fโDoes {first_party} winโ ์ ํ๋กฌํํธ๋ฅผ ๋ฃ์ด์ฃผ๊ณ , second_party๊ฐ ์์ผ๋ฉด fโDoes {second_party} winโ์ ๋ฃ์ด์ค, ๋ ๋ค ์๋ ๊ฒฝ์ฐ๋ โDoes Complainant win?โ ํ๋กฌํํธ๋ฅผ ๋ฃ์ด์ฃผ์๋ค.
๋ต๋ณ
โyesโ
first_party์ first_party_winner ๊ฐ ๊ฐ์ ๋ผ๋ฒจ 1==1, 0==0์ธ ๊ฒฝ์ฐ์ ํ๋กฌํํธ์ โComplainantโ๊ฐ ๋ฑ์ฅํ๊ณ , first_party_winner๊ฐ 1์ธ ๊ฒฝ์ฐ
โnoโ
first_party๊ฐ 0์ด๊ณ first_party_winner๊ฐ 1์ธ ๊ฒฝ์ฐ์๋ โNoโ, first_party๊ฐ 1์ด๊ณ first_party_winner๊ฐ 0์ธ ๊ฒฝ์ฐ, โComplainantโ๊ฐ ๋ฑ์ฅํ๊ณ first_party_winner๊ฐ 0์ธ ๊ฒฝ์ฐ