from typing import List
from rich.tree import Tree as RichTree
from rich import print as richprint
# import torch
import torch.nn as nn
import re
from collections import OrderedDict
from .vis_logging import get_logger
logger = get_logger(__name__)
class ModuleTree(RichTree):
    r"""A ``rich`` tree node that remembers the module name / info it was built from.

    The rendered label is derived from ``module_name``, ``info`` and the three
    colors by :meth:`set_label`; the remaining arguments are forwarded to the
    ``RichTree`` constructor unchanged.

    Args:
        module_name: Name shown first in the label (omitted when ``None``).
        info: Extra text shown after the name (omitted when ``None``). It is
            wrapped in parentheses unless ``is_param_node`` is true.
        is_param_node: Marks the node as a parameter node; selects
            ``param_color`` instead of ``type_color`` for ``info``.
        type_color: Color markup used for ``info`` on non-parameter nodes.
        param_color: Color markup used for ``info`` on parameter nodes.
        main_color: Color markup used for ``module_name``.
        style, guide_style, expanded, highlight: Passed through to ``RichTree``.
    """

    def __init__(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style="tree",
        guide_style="tree.line",
        expanded=True,
        highlight=False,
    ):
        # The label is computed from these attributes, so they must be set
        # before the base-class constructor receives the label.
        self.module_name = module_name
        self.info = info
        self.is_param_node = is_param_node
        self.type_color = type_color
        self.param_color = param_color
        self.main_color = main_color
        super().__init__(
            self.set_label(),
            style=style,
            guide_style=guide_style,
            expanded=expanded,
            highlight=highlight,
        )

    def add(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style=None,
        guide_style=None,
        expanded=True,
        highlight=False,
    ):
        r"""Create a child :class:`ModuleTree`, append it to ``children`` and return it.

        ``style``, ``guide_style`` and ``highlight`` fall back to this node's
        own values when passed as ``None``.
        """
        child_style = style if style is not None else self.style
        child_guide_style = guide_style if guide_style is not None else self.guide_style
        child_highlight = highlight if highlight is not None else self.highlight
        child = ModuleTree(
            module_name,
            info,
            is_param_node,
            type_color,
            param_color,
            main_color,
            style=child_style,
            guide_style=child_guide_style,
            expanded=expanded,
            highlight=child_highlight,
        )
        self.children.append(child)
        return child

    def set_label(self):
        r"""Rebuild ``self.label`` from the node's current attributes and return it."""
        text = "" if self.module_name is None else f"[{self.main_color}]{self.module_name}"
        if self.info is not None:
            if self.is_param_node:
                text += f" [{self.param_color}]{self.info}"
            else:
                text += f" [{self.type_color}]({self.info})"
        self.label = text
        return text
class Visualization(object):
    r"""
    Better visualization tool for *BIG* pretrained models.

    - Better repeated block representation
    - Clearer parameter position
    - and Visible parameter state.

    Args:
        plm (:obj:`torch.nn.Module`): The pretrained model, actually can be any pytorch module.
    """

    def __init__(self, plm: nn.Module):
        self.plm = plm
        # Color markup used for the different kinds of label fragments.
        self.type_color = "green"
        self.param_color = "cyan"
        self.duplicate_color = "red"
        self.normal_color = "white"
        self.virtual_color = "orange"
        self.not_common_color = "bright_black"
        self.no_grad_color = "rgb(0,70,100)"
        self.delta_color = "rgb(175,0,255)"

    def check_mode(self):
        r"""[NODOC] Validate the option combination stored by :meth:`structure_graph`.

        Raises:
            RuntimeError: if ``keep_non_params`` and ``common_structure`` are both
                set, or if ``common_structure`` is set without a ``mapping``.
        """
        if self.keep_non_params and self.common_structure:
            raise RuntimeError("keep_non_params can't be used will common_structure. The common structure only contains parameter nodes.")
        if self.common_structure and self.mapping is None:
            raise RuntimeError("Mapping hasn't been given.")

    def structure_graph(self,
                        rootname="root",
                        expand_params=False,
                        keep_non_params=False,
                        common_structure=False,
                        mapping=None,
                        only_common=False,
                        printTree=True,
                        ):
        r"""Draw the structure graph in command line.

        Args:
            rootname (:obj:`str`) The root node's name.
            keep_non_params (:obj:`bool`) Display the modules that do not have parameters, such as nn.Dropout.
            expand_params (:obj:`bool`) Display parameter information (shape, etc.) on separate lines.
            common_structure (:obj:`bool`) Whether to convert the structure into a common structure defined in structure_mapping.py. The not-common structure will be displayed in grey.
            only_common (:obj:`bool`) Whether to ignore the modules that are not in the common structure. This will result in a more compact view. Defaults to False.
            mapping (:obj:`dict`) The structure mapping. Must be provided if common_structure=True.
            printTree (:obj:`bool`) Whether to print the resulting tree with rich.

        Returns:
            :obj:`ModuleTree`: The root of the built tree (also stored as ``self.root_tree``).
        """
        self.keep_non_params = keep_non_params
        self.expand_params = expand_params
        self.rootname = rootname
        self.only_common = only_common
        self.common_structure = common_structure
        self.mapping = mapping
        self.check_mode()

        self.root_tree = ModuleTree(self.rootname)
        if common_structure:
            self.build_common_tree(self.plm, mapping, self.root_tree)
        else:
            self.build_tree(self.plm, self.root_tree)

        # Merge repeated sibling blocks, then (unless expanded) pull the
        # parameter nodes up into their owning module's label.
        self.prune_tree(self.root_tree)
        if not self.expand_params:
            self.fold_param_node(self.root_tree)

        if printTree:
            richprint(self.root_tree)
        return self.root_tree

    @staticmethod
    def _module_type_name(module):
        r"""[NODOC] Return the bare class name of ``module`` (e.g. ``"Linear"``)."""
        return type(module).__name__

    def is_leaf_module(self, module):
        r"""[NODOC] Whether the module is a leaf module (has no child modules)."""
        return next(iter(module.named_children()), None) is None

    def build_tree(self, module: nn.Module, tree: "ModuleTree" = None):
        r"""[NODOC] Build the original tree structure.

        Adds one node per child module under ``tree`` (with its own parameter
        nodes) and recurses into it. A leaf module adds nothing.
        """
        for name, child in module.named_children():
            newnode = tree.add(name, info=self._module_type_name(child), type_color=self.type_color)
            self.add_param_info_node(child, newnode)
            self.build_tree(module=child, tree=newnode)

    def has_parameter(self, module):
        r"""[NODOC] Whether ``module.parameters()`` yields at least one parameter."""
        return any(True for _ in module.parameters())

    def build_common_tree(self, module: nn.Module, mapping, tree: "ModuleTree" = None, query="", key_to_root=""):
        r""" (Unstable) build the common tree structure.

        ``query`` accumulates the dotted path of child names that have not yet
        matched a key of ``mapping``; once ``query + child_name`` is a key (or
        the wildcard key ``"$"`` is present) a node is created and the recursion
        continues with the nested mapping and an empty query.
        """
        if self.is_leaf_module(module):
            # A leaf reached with a pending query was never matched by the mapping.
            if len(query) > 0 and self.has_parameter(module):
                logger.warning(f"Parameter node {query} not found under tree {tree.module_name} and module {module}. Is your mapping correct?")
            return

        for n, m in module.named_children():
            new_query = query + n
            type_info = self._module_type_name(m)

            if new_query in mapping:
                new_mapping = mapping[new_query]
                name = new_mapping["__name__"]
                hierachical_name = name.split(".")
                if len(hierachical_name) > 1:
                    # The new key contains a hierarchy: unfold it, inserting
                    # virtual nodes for the missing intermediate levels.
                    temp_tree = self.find_or_insert(tree, hierachical_name)
                    newnode = temp_tree.add(hierachical_name[-1], info=type_info, type_color=self.type_color)
                elif name == "":
                    # The key is not part of the predefined common structure.
                    if self.only_common:
                        continue
                    # Keep the original name, rendered in the "not common" color.
                    newnode = tree.add(new_query, info=type_info, main_color=self.not_common_color, type_color=self.not_common_color)
                else:
                    # A single new key: reuse an existing child of that name if any.
                    newnode = self.find_not_insert(tree, [name, ""])
                    if newnode is not None:
                        newnode.info = type_info
                        newnode.type_color = self.type_color
                        newnode.set_label()
                    else:
                        newnode = tree.add(name, info=type_info, type_color=self.type_color)
            elif "$" in mapping:
                # "$" matches anything in this field.
                new_mapping = mapping["$"]
                newnode = tree.add(n, info=type_info, type_color=self.type_color)
            else:
                # No match yet: descend with a longer query, staying on the same tree node.
                self.build_common_tree(module=m, tree=tree, mapping=mapping, query=new_query + ".", key_to_root=key_to_root)
                continue

            self.add_param_info_node(m, newnode)
            self.build_common_tree(module=m, tree=newnode, mapping=new_mapping, key_to_root=key_to_root + "." + new_query)

    def find_or_insert(self, tree: "ModuleTree", hierachical_name: List[str]):
        r"""[NODOC] Find the node; if it is not found, insert a virtual node.

        Walks every name except the last one and returns the node reached.
        """
        if len(hierachical_name) == 1:
            return tree
        head = hierachical_name[0]
        node = next((c for c in tree.children if c.module_name == head), None)
        if node is None:
            node = tree.add(head, info="Virtual", type_color=self.virtual_color)
        return self.find_or_insert(node, hierachical_name=hierachical_name[1:])

    def find_not_insert(self, tree: "ModuleTree", hierachical_name: List[str]):
        r"""[NODOC] Find the node but do not insert.

        Walks every name except the last one; returns the node reached, or
        ``None`` as soon as a name has no matching child.
        """
        if len(hierachical_name) == 1:
            return tree
        head = hierachical_name[0]
        node = next((c for c in tree.children if c.module_name == head), None)
        if node is None:
            return None
        return self.find_not_insert(node, hierachical_name=hierachical_name[1:])

    def fold_param_node(self, t: "ModuleTree", p: "ModuleTree" = None):
        r"""[NODOC] Place the parameters' information node right after the module that contains the parameters.

        E.g.    w1 (Linear)
                    -- weight: [32128, 1024]
                =>
                w1 (Linear) weight: [32128, 1024]

        Returns:
            bool: whether ``t`` should be removed from its parent's children.
        """
        if getattr(t, "is_param_node", False):
            p.label += t.label
            return True
        if len(t.children) == 0:
            # Childless non-parameter nodes are dropped unless keep_non_params is set.
            return not self.keep_non_params
        t.children = [c for c in t.children if not self.fold_param_node(t=c, p=t)]
        return False

    def prune_tree(self, t: "ModuleTree"):
        r"""[NODOC] Calculate the _finger_print of a module as the _finger_print of all child nodes plus the _finger_print of itself.

        The leaf node will have the _finger_print == label.
        Sibling nodes whose _finger_print is identical once their own label is
        stripped are merged into a single node.
        """
        if len(t.children) == 0:
            t._finger_print = t.label
            return

        for sub_tree in t.children:
            self.prune_tree(sub_tree)
        t._finger_print = t.label + "::" + ";".join(x._finger_print for x in t.children)

        # Group the children by fingerprint-without-head, preserving first-seen order.
        groups = OrderedDict()
        for sub_tree in t.children:
            head, sep, rest = sub_tree._finger_print.partition("::")
            fname = rest if sep else head
            groups.setdefault(fname, []).append(sub_tree)

        t.children = [self.extract_common_and_join(group) for group in groups.values()]

    def extract_common_and_join(self, l: List["ModuleTree"]):
        r"""[NODOC] Some modules that have the same info (e.g., are all "Linear") have different names (e.g., w1,w2).
        Merge them.

        E.g.  tree1.module_name = "w1", tree1.info = "Linear"; tree2.module_name = "w2", tree2.info = "Linear"
            -> representative.module_name = "w1,w2", representative.info = "Linear"

        The first node of ``l`` is updated in place and returned; a single-node
        list is returned untouched.
        """
        representative = l[0]
        if len(l) == 1:
            return representative

        # info -> names of the nodes carrying that info, in first-seen order.
        type_hint_dict = OrderedDict()
        for node in l:
            type_hint_dict.setdefault(node.info, []).append(node.module_name)

        names, typeinfos, labels = [], [], []
        for info, members in type_hint_dict.items():
            group_components = self.neat_expr(members)
            names.append(group_components)
            typeinfos.append(info)
            labels.append(f"[{self.duplicate_color}]{group_components}[{self.type_color}]({info})")

        joined_info = ",".join(typeinfos)
        representative.module_name = ",".join(names)
        # ``info`` is the attribute the rest of the class reads; ``type_info``
        # is kept as well because earlier versions stored the value there.
        representative.info = joined_info
        representative.type_info = joined_info
        representative.label = ",".join(labels)
        return representative

    def neat_expr(self, l: List[str]):
        r"""[NODOC] A small tool function to arrange consecutive numbers into interval display.

        E.g., ["1","2","3","5","6","9","10","11","12"] -> "1-3,5-6,9-12"

        An isolated number ``k`` is rendered as ``"k-k"``. If any entry is not
        an integer string, the entries are joined with commas unchanged.
        """
        try:
            spans = self.ranges([int(x.strip()) for x in l])
        except (AttributeError, TypeError, ValueError):
            return ",".join(l)
        return ",".join(f"{lo}-{hi}" for lo, hi in spans)

    def ranges(self, nums: List[int]):
        r"""[NODOC] A small tool function to arrange consecutive numbers into interval display.

        E.g., [1,2,3,5,6,9,10,11,12] -> [(1,3),(5,6),(9,12)]
        """
        nums = sorted(set(nums))
        # Each gap contributes the end of one run and the start of the next.
        gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
        edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
        return list(zip(edges, edges))

    def add_param_info_node(self, m: nn.Module, tree: "ModuleTree", record_grad_state=True, record_delta=True):
        r"""[NODOC] Add parameter information of the module. The parameters that are not inside a child module (i.e., created using nn.Parameter) will be added in this function.

        Each such parameter becomes a parameter node ``"<name>:<shape>"`` under
        ``tree``. Its color is ``delta_color`` when ``record_delta`` is set and
        the parameter has a truthy ``_is_delta`` attribute, else
        ``no_grad_color`` when ``record_grad_state`` is set and
        ``requires_grad`` is false, else ``param_color``.

        Raises:
            RuntimeError: if a parameter name does not start with a child-module
                name but still contains ``'.'``.
        """
        known_module = {n for n, _ in m.named_children()}
        for n, p in m.named_parameters():
            if n.split(".")[0] in known_module:
                # Belongs to a child module; it is reported when that child is visited.
                continue
            if "." in n:
                raise RuntimeError(f"The name field {n} should be a parameter since it doesn't appear in named_children, but it contains '.'")
            info = "{}:{}".format(n, list(p.shape))
            color = self.param_color
            if record_grad_state and not p.requires_grad:
                color = self.no_grad_color
            if record_delta and getattr(p, "_is_delta", False):
                color = self.delta_color
            tree.add(info=info, is_param_node=True, param_color=color)
if __name__ == "__main__":
    # example command line:
    # 1. python opendelta/utils/visualization.py --model t5-lm --model_name_or_path t5-large-lm-adapt --common_structure --only_common
    # 2. python opendelta/utils/visualization.py --model roberta --model_name_or_path roberta-large --common_structure
    # 3. python opendelta/utils/visualization.py --model gpt2 --model_name_or_path gpt2-medium --keep_non_params --expand_params
    import argparse
    import os

    from openprompt.plms import load_plm

    parser = argparse.ArgumentParser("")
    parser.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.")
    parser.add_argument("--model_name_or_path", default="t5-large")
    parser.add_argument("--cache_base", default='/home/hushengding/plm_cache/')
    parser.add_argument("--keep_non_params", action="store_true", help="Display the modules that does not have parameters, such as nn.Dropout")
    parser.add_argument("--expand_params", action="store_true", help="Display parameter infomation (shape, etc) in seperate lines. ")
    parser.add_argument("--common_structure", action="store_true", help="Whether convert the structure into a common structure defined in structure_mapping.py. The not common structure will be displayed in grey." )
    parser.add_argument("--only_common", action="store_true", help="Whether ignore the modules that are not in common structure. This will result in a more compact view. Default to False")
    args = parser.parse_args()

    # os.path.join keeps the path correct whether or not --cache_base ends with
    # a separator. Note that an absolute --model_name_or_path replaces the base.
    model_path = os.path.join(args.cache_base, args.model_name_or_path)
    plm, tokenizer, model_config, WrapperClass = load_plm(args.model, model_path)
    print("Model Loaded!")

    if args.common_structure:
        # Imported lazily: the mappings are only needed in common-structure mode.
        from opendelta.utils.structure_mapping import Mappings
        mapping = Mappings[args.model]
    else:
        mapping = None

    visobj = Visualization(plm)
    visobj.structure_graph(
        rootname=args.model_name_or_path,
        keep_non_params=args.keep_non_params,
        expand_params=args.expand_params,
        common_structure=args.common_structure,
        only_common=args.only_common,
        mapping=mapping,
    )