From 9833cc8abaf596f5e4a333a14d36b04824bb942f Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 18 Oct 2021 15:48:05 +0800
Subject: [PATCH] Add GLM Model

---
 model/glm_model.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 model/glm_model.py

diff --git a/model/glm_model.py b/model/glm_model.py
new file mode 100644
index 0000000..0edf9b9
--- /dev/null
+++ b/model/glm_model.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+
+
+class GLMModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
+        torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02)
+
+    def position_embedding_forward(self, position_ids, *other_tensors):
+        position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
+        position_embeddings = self.transformer.position_embeddings(position_ids)
+        block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids)
+        return position_embeddings + block_position_embeddings
-- 
GitLab