Skip to content
Snippets Groups Projects
Commit 9833cc8a authored by Zhengxiao Du's avatar Zhengxiao Du
Browse files

Add GLM Model

parent 8584bf91
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn as nn
from .base_model import BaseModel
class GLMModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02)
def position_embedding_forward(self, position_ids, *other_tensors):
position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
position_embeddings = self.transformer.position_embeddings(position_ids)
block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids)
return position_embeddings + block_position_embeddings
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment