์†๊ธฐํ›ˆ - LLM for Daiconยถ

Causal LMยถ

what is causal langauge model?ยถ

  • GPT, LLaMA, Alpaca, PaLM, etc.

llm_paint.png

  • ์ด์ „์˜ ์ž…๋ ฅ ํ† ํฐ์„ ํ†ตํ•ด ๋‹ค์Œ ํ† ํฐ์„ ์ƒ์„ฑํ•˜๋Š” ์ž๊ธฐํšŒ๊ท€์  ๋ชจ๋ธ์ด๋‹ค. ์ด์ „์˜ ๋ฌธ๋งฅ์„ ๋ฐ˜์˜ํ•˜์—ฌ ์ž…๋ ฅ์— ๋Œ€ํ•œ ์ถœ๋ ฅ์„ ๋‚ด๋†“๊ธฐ ๋•Œ๋ฌธ์— Causal LanguageModel๋กœ ๋ถˆ๋ฆฐ๋‹ค.

Prompt Tuningยถ

ํ”„๋กฌํ”„ํŠธ ํŠœ๋‹์€ ํŒŒ์ธํŠœ๋‹์˜ ํ•œ ๊ฐˆ๋ž˜๋กœ์„œ, auto regressive ํ•œ ๋ชจ๋ธ์— ์ธํ’‹๋ฐ์ดํ„ฐ์— ํŠน์ •ํ•œ ์ธ์Šคํ„ฐ๋Ÿญ์…˜์„ ์ฃผ์–ด ์›ํ•˜๋Š” ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•˜๊ฒŒ ํ•˜๋Š” ํƒœ์Šคํฌ์ด๋‹ค.

  • Alpaca Finetuning

    • Alpaca-7b ๋ชจ๋ธ์„ ํŒŒ์ธ ํŠœ๋‹ํ•˜๋Š” ๊ฒƒ ์ž์ฒด๋Š” ์‚ฌ์‹ค ์–ด๋ ค์šด ํƒœ์Šคํฌ๊ฐ€ ์•„๋‹ˆ์—ˆ๊ณ , ์ฝ”๋“œ ํ•œ ์ค„๋กœ๋„ ๊ฐ€๋Šฅํ•œ ์ผ์ด์—ˆ๋‹ค.

    • ์•„๋ž˜์˜ ํฌ๋งท์— ๋งž์ถ”์–ด jsonํ˜•ํƒœ๋กœ ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์™€ ํŠœ๋‹์„ ์‹œํ‚ค๋ฉด ๋˜์—ˆ๋‹ค.

      ### ํŒŒ์ธ ํŠœ๋‹ ์‹œ ํ”„๋กฌํ”„ํŠธ ํฌ๋งท
      
      def generate_prompt(data_point):
          return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
      ### Instruction:
      {data_point["instruction"]}
      ### Input:
      {data_point["input"]}
      ### Response:
      {data_point["output"]}"""
      
      ### inference ์‹œ์˜ ํ”„๋กฌํ”„ํŠธ ํฌ๋งท
      
      def create_prompt(data_point):
          return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
      ### Instruction:
      {data_point["instruction"]}
      ### Input:
      {data_point["input"]}
      ### Response:
      """
      
      ### inference ๋‹ต๋ณ€ ์ƒ์„ฑ ํ•จ์ˆ˜๋“ค
      
      def generate_response(prompt: str, model: model):
          encoding = tokenizer(prompt, return_tensors="pt")
          input_ids = encoding["input_ids"].to(DEVICE)
       
          generation_config = GenerationConfig(
              temperature=0.1,
              top_p=0.75,
              repetition_penalty=1.1,
          )
          with torch.inference_mode():
              return model.generate(
                  input_ids=input_ids,
                  generation_config=generation_config,
                  return_dict_in_generate=True,
                  output_scores=True,
                  max_new_tokens=1,
              )
      
      def format_response(response) -> str:
          decoded_output = tokenizer.decode(response.sequences[0]) 
          response = decoded_output.split("### Response:")[1].strip()
          return "\n".join(textwrap.wrap(response))
      
      def ask_alpaca(data_point, model) -> str:
          prompt = create_prompt(data_point)
          response = generate_response(prompt, model)  # a very large output object is generated
          response = format_response(response)
          return response
      
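    • As a hedged sketch of that loading step (the file name alpaca_data.json and the use of the datasets library are assumptions, not the original code):

      # Hypothetical sketch: load the JSON records and render each one
      # through generate_prompt() defined above.
      from datasets import load_dataset
      
      data = load_dataset("json", data_files="alpaca_data.json")
      data = data["train"].map(lambda x: {"text": generate_prompt(x)})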
      
  • Loss Function of an Autoregressive Model - an aside

    • ์ƒ์„ฑ ๋ชจ๋ธ์˜ loss function์€ ์–ด๋–ป๊ฒŒ ๊ตด๋Ÿฌ๊ฐ€๋Š”์ง€ ๊ถ๊ธˆํ•ด์„œ ํ•œ๋ฒˆ ์ฐพ์•„๋ด„

    ## huggning face์— nlp course์— ์žˆ๋Š” ์ฝ”๋“œ
    
    from torch.nn import CrossEntropyLoss
    import torch
    
    def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
        # Shift so that tokens < n predict n
        shift_labels = inputs[..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()
        # Calculate per-token loss
        loss_fct = CrossEntropyLoss(reduction="none")  # keep per-token losses
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Resize and average loss per sample
        loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)
        # Calculate and scale weighting
        weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(
            axis=[0, 2]
        )
        weights = alpha * (1.0 + weights)
        # Calculate weighted average
        weighted_loss = (loss_per_sample * weights).mean()
        return weighted_loss
    
    • The target of Token_1 is Token_2.

    • Therefore the target tokens run from the 2nd token to the last one, and the input tokens from the 1st token to the second-to-last one (illustrated below).

    • After the per-token loss is computed, a weighted average over all samples is taken, using how often the key tokens appear as the weights.

Prompt tuning for classification¶

  • Example from Hugging Face's PEFT library (a sketch of attaching the config follows the snippet below)

    peft_config = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        prompt_tuning_init=PromptTuningInit.TEXT,
        num_virtual_tokens=8,
        prompt_tuning_init_text="Classify if the tweet is a complaint or not:", ## the instruction can be passed in separately here
        tokenizer_name_or_path=model_name_or_path,
    )
    
    ...
    
    f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"} Label : ',
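
  • A minimal sketch of attaching that config, assuming model_name_or_path points at a causal LM checkpoint:

    from transformers import AutoModelForCausalLM
    from peft import get_peft_model
    
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    model = get_peft_model(model, peft_config)
    # only the num_virtual_tokens prompt embeddings are trainable
    model.print_trainable_parameters()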
    
  • Alpaca์˜ ์˜ˆ์‹œ - ์œ„ ์ฐธ๊ณ 

Dacon์— ์ ์šฉยถ

๋ฐฐ๊ฒฝยถ

  • ์›๊ณ ์˜ ์Šน์†Œ ์—ฌ๋ถ€๋ฅผ ๋ฌผ์–ด๋ณด๋Š” ๊ณผ์ œ์˜€์œผ๋ฏ€๋กœ, ๋‹จ์ˆœํ•œ classification ๋ณด๋‹ค ๋ณต์žกํ•œ ๋ฌธ์ œ๋ผ๊ณ  ์ƒ๊ฐํ–ˆ๋‹ค.

  • ์ด ๋ฌธ์ œ๋ฅผ alpaca์— prompt-tuning์œผ๋กœ ์ ‘๊ทผํ•˜์—ฌ ํ•ด๊ฒฐ ๊ฐ€๋Šฅํ•œ ๋ฌธ์ œ๋ผ๊ณ  ์ƒ๊ฐํ•จ.

๋ฐ์ดํ„ฐ ํƒ์ƒ‰ ๋ฐ ์ „์ฒ˜๋ฆฌยถ

  1. ๋ฐ์ดํ„ฐ ๊ธธ์ด ํƒ์ƒ‰

    • train length

    (plot: distribution of train text lengths)

    • test length

    (plot: distribution of test text lengths)
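
    • A hedged sketch of how such length plots can be produced (pandas/matplotlib; the file and column names are assumptions):

      import pandas as pd
      import matplotlib.pyplot as plt
      
      train = pd.read_csv("train.csv")
      train["facts"].str.len().hist(bins=50)
      plt.xlabel("facts length (characters)")
      plt.show()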

  2. ๋ผ๋ฒจ ํƒ์ƒ‰

    (plot: label distribution)

  3. ner ํƒœ๊ทธ ํƒ์ƒ‰

    • first_party

    (plot: first_party NER tags)

    • second_party

    (plot: second_party NER tags)

    • first_party, second_party์˜ ner์„ ํƒ์ƒ‰ํ•˜์—ฌ ์ธ๋ช…, ์ง€๋ช…, ์กฐ์ง ๋“ฑ๋“ฑ์— ๋งž์ถ”์–ด ๋ฃฐ ๋ฒ ์ด์Šค๋กœ ์ฒ˜๋ฆฌํ•˜์—ฌ facts ์•ˆ์— ๋“ฑ์žฅํ•˜๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ์ฐพ๊ณ ์ž ํ•จ

      1. full name์ด ์ด๋ฏธ facts์•ˆ์— ์žˆ๋Š” ๊ฒฝ์šฐ

        • ์ด ๊ฒฝ์šฐ์—” ner ์ƒ๊ด€ ์—†์ด ๋ฐ”๋กœ prompt๋ฅผ ์ƒ์„ฑ ๊ฐ€๋Šฅํ•˜๋ฏ€๋กœ ๋”ฐ๋กœ ๋นผ์–ด ๊ด€๋ฆฌ

        • 1,550 rows

      2. full_name์ด facts์•ˆ์— ์—†๋Š” ๊ฒฝ์šฐ

        ner์„ ์ ์šฉํ•œ ํ›„์—, ner ํƒœ๊ทธ๊ฐ€ ํ•˜๋‚˜๋งŒ ์žˆ๋Š” ๋ฐ์ดํ„ฐ๋“ค์„ ๋ชจ์•„์„œ ์ฒ˜๋ฆฌ

        • PER ํƒœ๊ทธ : ๊ฐ„๋‹จํ•œ ํด๋ Œ์ง• ์ดํ›„์— first_name๊ณผ last_name์ด ์žˆ๋Š”์ง€๋ฅผ ์ฒดํฌํ•˜์˜€๋‹ค.

        • ORG, LOC, MISC tags: after simple cleaning, applied an n-gram match to check whether the entity appears in the facts (see the sketch after this list).

        • 498 rows

      3. full_name์ด facts์•ˆ์— ์—†๊ณ  ner tag๊ฐ€ ์—ฌ๋Ÿฌ๊ฐœ์ธ ๊ฒฝ์šฐ

        • ์ด ๊ฒฝ์šฐ์—๋Š” 2๋ฒˆ์˜ ๊ฒฝ์šฐ๋ณด๋‹ค ์ผ๋ฐ˜์ ์œผ๋กœ ์ฒ˜๋ฆฌํ•ด์•ผ ํ–ˆ์œผ๋ฏ€๋กœ n-gram์„ ์ ์šฉํ•˜์—ฌ facts์•ˆ์— ์žˆ๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ์ฒดํฌํ•˜์˜€๋‹ค.

        • 430 rows
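
        • A hedged sketch of that n-gram containment check (whitespace tokenization and the helper names are assumptions, not the competition code):

          # Does any n-gram of the party string occur among the n-grams of the
          # facts? n shrinks to the party's token count when it is shorter.
          def ngrams(tokens, n):
              return {" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}
          
          def party_in_facts(party: str, facts: str, n: int = 2) -> bool:
              party_tokens = party.lower().split()
              k = min(n, len(party_tokens))
              fact_grams = ngrams(facts.lower().split(), k)
              return any(g in fact_grams for g in ngrams(party_tokens, k))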

  4. ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ

    • ๋งŒ์•ฝ first_party๊ฐ€ ์žˆ์œผ๋ฉด, fโ€œDoes {first_party} winโ€ ์˜ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋„ฃ์–ด์ฃผ๊ณ , second_party๊ฐ€ ์žˆ์œผ๋ฉด fโ€Does {second_party} winโ€์„ ๋„ฃ์–ด์คŒ, ๋‘˜ ๋‹ค ์—†๋Š” ๊ฒฝ์šฐ๋Š” โ€œDoes Complainant win?โ€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋„ฃ์–ด์ฃผ์—ˆ๋‹ค.

    • ๋‹ต๋ณ€

      • โ€˜yesโ€™

        • first_party์™€ first_party_winner ๊ฐ€ ๊ฐ™์€ ๋ผ๋ฒจ 1==1, 0==0์ธ ๊ฒฝ์šฐ์™€ ํ”„๋กฌํ”„ํŠธ์— โ€˜Complainantโ€™๊ฐ€ ๋“ฑ์žฅํ•˜๊ณ , first_party_winner๊ฐ€ 1์ธ ๊ฒฝ์šฐ

      • โ€˜noโ€™

        • first_party๊ฐ€ 0์ด๊ณ  first_party_winner๊ฐ€ 1์ธ ๊ฒฝ์šฐ์—๋Š” โ€˜Noโ€, first_party๊ฐ€ 1์ด๊ณ  first_party_winner๊ฐ€ 0์ธ ๊ฒฝ์šฐ, โ€˜Complainantโ€™๊ฐ€ ๋“ฑ์žฅํ•˜๊ณ  first_party_winner๊ฐ€ 0์ธ ๊ฒฝ์šฐ