Introduction to Prompt Tuning

This article introduces Prompt Tuning by pairing the original paper with the PEFT source code.

Where the trainable parameters live

Prompt Tuning introduces only a single layer of trainable parameters, placed at the input layer of the Transformer (the green part in the figure).
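Before diving into the source, here is a minimal sketch of attaching Prompt Tuning with the peft library; the backbone `gpt2` and the hyperparameter values are illustrative choices, not from the original post:

```python
from transformers import AutoModelForCausalLM
from peft import PromptTuningConfig, TaskType, get_peft_model

# "gpt2" and num_virtual_tokens=8 are illustrative choices
model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=8,  # prompt length P
)
model = get_peft_model(model, peft_config)
# Only the prompt embedding is trainable; the backbone stays frozen
model.print_trainable_parameters()
# prints something like: trainable params: 6,144 || all params: 124,445,952 || ...
```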

Parameter structure

PEFT first grabs the backbone's word embedding layer, then builds the prompt encoder from it:

```python
# peft_model.py, line 389
# named_param = 'word_embeddings.weight'
self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))


# peft_model.py, line 393
# config holds the Prompt Tuning hyperparameters
# self.word_embeddings is the word embedding layer at the bottom of the original Transformer
prompt_encoder = PromptEmbedding(config, self.word_embeddings)
```
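As a side note, a minimal sketch of what `get_submodule` resolves to, assuming a GPT-2 backbone (the module path `transformer.wte` is GPT-2-specific):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# For GPT-2 the input word embedding lives at "transformer.wte",
# so named_param would be "transformer.wte.weight"
word_embeddings = model.get_submodule("transformer.wte.weight".replace(".weight", ""))
print(word_embeddings)  # Embedding(50257, 768)
```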

The `prompt_encoder` object is used at inference time; its number of trainable real-valued parameters is

$$T = E \cdot P$$

  • $T$: number of trainable parameters
  • $E$: dimension of a token embedding
  • $P$: preset prompt length

The parameter cost of our method is EP, where E is the token embedding dimension and P is the prompt length.
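A quick sanity check of this count, using the illustrative values from the forward-pass example later in this post ($E = 1024$, $P = 8$):

```python
# Worked example of T = E * P (illustrative values)
E = 1024  # token embedding dimension
P = 8     # prompt length
T = E * P
print(T)  # 8192 trainable parameters -- tiny next to the frozen backbone
```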

PEFT's `PromptEmbedding` holds these parameters and implements their initialization:

```python
# peft source (indentation restored; imports added for readability)
import math

import torch
from peft import PromptTuningInit


class PromptEmbedding(torch.nn.Module):
    def __init__(self, config, word_embeddings):
        super().__init__()
        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
            from transformers import AutoTokenizer

            tokenizer_kwargs = config.tokenizer_kwargs or {}
            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                # truncate an over-long init text
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                # tile a too-short init text, then cut to the exact length
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
                init_token_ids = init_token_ids[:total_virtual_tokens]
            init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device)

            # look up the backbone's embeddings for these ids and copy them
            # (detached) as the initial value of the prompt embedding
            word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
```

`total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules` is the total number of virtual tokens that Prompt Tuning is allowed to insert.

`config.prompt_tuning_init_text` is the text used to initialize the prompt.

The logic of this code is:

  1. Receive the preset Prompt Tuning config together with the word embedding layer of the original model's bottom Transformer
  2. Load the tokenizer used by the original model
  3. Use the tokenizer to convert the prompt text into token ids, then truncate or tile them to length total_virtual_tokens (see the sketch after this list)
  4. Use the resized token ids to select initialization parameters from the embedding layer (PEFT's design choice)
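A minimal sketch of step 3's resize logic on toy token ids (the id values and `total_virtual_tokens = 8` are made up):

```python
import math

# toy ids standing in for tokenizer(init_text)["input_ids"]
init_token_ids = [101, 7592, 2088, 102]  # 4 tokens
total_virtual_tokens = 8

num_text_tokens = len(init_token_ids)
if num_text_tokens > total_virtual_tokens:
    init_token_ids = init_token_ids[:total_virtual_tokens]
elif num_text_tokens < total_virtual_tokens:
    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)  # 2 here
    init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]

print(init_token_ids)  # [101, 7592, 2088, 102, 101, 7592, 2088, 102]
```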

The original paper's discussion of initialization **(2.1 Design Decisions)**:

A more sophisticated option is to initialize each prompt token to an embedding drawn from the model’s vocabulary.

Since we want the model to produce these tokens in the output, initializing the prompt with the embeddings of the valid target tokens should prime the model to restrict its output to the legal output classes.

Forward pass

Suppose the word-vector dimension is 1024 and the input before Prompt Tuning has $input.shape = 8 \times 8000 \times 1024$ (batch $\times$ sequence $\times$ embedding). Introducing a prompt of length 8 adds a parameter matrix $prompt\_encoder.embedding.weight$ of shape $8 \times 1024$, which is repeated along the batch dimension to $8 \times 8 \times 1024$ and concatenated in front of the input, yielding $\hat{input}.shape = 8 \times (8000 + 8) \times 1024 = 8 \times 8008 \times 1024$.

```python
# peft_model.py, lines 1123~1130
inputs_embeds = self.word_embeddings(input_ids)                     # original input embeddings
prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)  # Prompt Tuning parameters
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)          # combined input
```
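A self-contained sketch of this concatenation with the shapes from the example above (all names and values are illustrative):

```python
import torch

batch_size, seq_len, E, P = 8, 8000, 1024, 8  # values from the example above

inputs_embeds = torch.randn(batch_size, seq_len, E)   # 8 x 8000 x 1024
embedding = torch.nn.Embedding(P, E)                  # weight: 8 x 1024
prompts = embedding.weight.repeat(batch_size, 1, 1)   # 8 x 8 x 1024
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)

print(inputs_embeds.shape)  # torch.Size([8, 8008, 1024])
```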

References

huggingface/peft: https://github.com/huggingface/peft

Lester, B., Al-Rfou, R., and Constant, N. (2021). The Power of Scale for Parameter-Efficient Prompt Tuning. EMNLP 2021.

