From 9833cc8abaf596f5e4a333a14d36b04824bb942f Mon Sep 17 00:00:00 2001 From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn> Date: Mon, 18 Oct 2021 15:48:05 +0800 Subject: [PATCH] Add GLM Model --- model/glm_model.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 model/glm_model.py diff --git a/model/glm_model.py b/model/glm_model.py new file mode 100644 index 0000000..0edf9b9 --- /dev/null +++ b/model/glm_model.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel + + +class GLMModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size) + torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02) + + def position_embedding_forward(self, position_ids, *other_tensors): + position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1] + position_embeddings = self.transformer.position_embeddings(position_ids) + block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids) + return position_embeddings + block_position_embeddings -- GitLab