前言
半夜coding写bug发现一个很冷僻的Python嵌套函数的坑点,先发出来给大家看看有没有遇到过类似的问题,暂时还不是很搞得清楚原理是什么。
不知道大家写嵌套函数的时候有没有加下划线作为变量名称前缀的习惯,笔者的习惯一直都是每嵌套一层,嵌套函数的名称及其所有的新变量都额外添加一个下划线作为前缀,这样的好处是可以确保内外变量名不会发生重复,从而防止嵌套函数内意外改动外层函数的变量值。
比如下面这种代码结构的写法就比较满足我的强迫症(凑字数):
def generate_dataloader(args, mode='train', do_export=False, pipeline='judgment', for_debug=False):
dataset = BasicDataset(args=args,
mode=mode,
do_export=do_export,
pipeline=pipeline,
for_debug=for_debug)
column = dataset.data.columns.tolist()
if mode.startswith('train'):
batch_size = args.train_batch_size
shuffle = True
if mode.startswith('valid'):
batch_size = args.valid_batch_size
shuffle = False
if mode.startswith('test'):
batch_size = args.test_batch_size
shuffle = False
def _collate_fn(_batch_data):
def __collate_id():
return [__data['id'] for __data in _batch_data]
def __collate_type():
return [__data['type'] for __data in _batch_data]
def __collate_subject():
return torch.LongTensor([__data['subject'] for __data in _batch_data])
def __collate_label_choice():
return torch.LongTensor([__data['label_choice'] for __data in _batch_data])
def __collate_label_judgment():
return torch.LongTensor([__data['label_judgment'] for __data in _batch_data])
def __collate_option_id():
return [__data['option_id'] for __data in _batch_data]
if args.word_embedding is None and args.document_embedding is None:
def __collate_question():
return torch.LongTensor([__data['question'] for __data in _batch_data])
def __collate_reference():
return torch.LongTensor([__data['reference'] for __data in _batch_data])
def __collate_options():
return torch.LongTensor([__data['options'] for __data in _batch_data])
def __collate_option():
return torch.LongTensor([__data['option'] for __data in _batch_data])
else:
def __collate_question():
if isinstance(_batch_data[0]['question'], numpy.ndarray):
return torch.FloatTensor([__data['question'] for __data in _batch_data])
elif isinstance(_batch_data[0]['question'], torch.Tensor):
return torch.stack([__data['question'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['question']))
def __collate_reference():
if isinstance(_batch_data[0]['reference'], numpy.ndarray):
return torch.FloatTensor([__data['reference'] for __data in _batch_data])
elif isinstance(_batch_data[0]['reference'], torch.Tensor):
return torch.stack([__data['reference'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['reference']))
def __collate_options():
if isinstance(_batch_data[0]['options'], numpy.ndarray):
return torch.FloatTensor([__data['options'] for __data in _batch_data])
elif isinstance(_batch_data[0]['options'], torch.Tensor):
return torch.stack([__data['options'] for __data in _batch_data])
elif isinstance(_batch_data[0]['options'], list):
if isinstance(_batch_data[0]['options'][0], numpy.ndarray):
return torch.FloatTensor(numpy.stack([numpy.stack(__data['options']) for __data in _batch_data]))
elif isinstance(_batch_data[0]['options'][0], torch.Tensor):
return torch.stack([torch.stack(__data['options']) for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['options'][0]))
else:
raise NotImplementedError(type(_batch_data[0]['options']))
def __collate_option():
if isinstance(_batch_data[0]['option'], numpy.ndarray):
return torch.FloatTensor([__data['option'] for __data in _batch_data])
elif isinstance(_batch_data[0]['option'], torch.Tensor):
return torch.stack([__data['option'] for __data in _batch_data])
else:
raise NotImplementedError(type(_batch_data[0]['option']))
if args.use_pos_tags:
def __collate_pos_tags(__column):
return torch.LongTensor([list(map(lambda __pos_tag: STANFORD_POS_TAG_INDEX.get(__pos_tag, -1) + 1, __data[__column])) for __data in _batch_data])
def __collate_statement_pos_tags():
return __collate_pos_tags(__column='statement_pos_tags')
def __collate_option_a_pos_tags():
return __collate_pos_tags(__column='option_a_pos_tags')
def __collate_option_b_pos_tags():
return __collate_pos_tags(__column='option_b_pos_tags')
def __collate_option_c_pos_tags():
return __collate_pos_tags(__column='option_c_pos_tags')
def __collate_option_d_pos_tags():
return __collate_pos_tags(__column='option_d_pos_tags')
if args.use_reference:
def __collate_reference_pos_tags():
return torch.LongTensor([[list(map(lambda ___pos_tag: STANFORD_POS_TAG_INDEX.get(___pos_tag, -1) + 1, ___pos_tags)) for ___pos_tags in __data['reference_pos_tags']] for __data in _batch_data])
if args.use_parse_tree:
def __collate_parse_tree(__column):
return [[parse_tree_to_graph(__parse_tree) for __parse_tree in __data[__column]] for __data in _batch_data]
def __collate_statement_tree():
return __collate_parse_tree(__column='statement_tree')
def __collate_option_a_tree():
return __collate_parse_tree(__column='option_a_tree')
def __collate_option_b_tree():
return __collate_parse_tree(__column='option_b_tree')
def __collate_option_c_tree():
return __collate_parse_tree(__column='option_c_tree')
def __collate_option_d_tree():
return __collate_parse_tree(__column='option_d_tree')
if args.use_reference:
def __collate_reference_tree():
return [[[parse_tree_to_graph(___parse_tree) for ___parse_tree in ___parse_trees] for ___parse_trees in __data['reference_tree']] for __data in _batch_data]
_collate_data = {}
for _column in column:
_collate_data[_column] = eval(f'__collate_{_column}')()
return _collate_data
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=args.num_workers,
collate_fn=_collate_fn,
shuffle=shuffle)
return dataloader
本来一切都很完美,直到开始出现问题。
问题发现
下面是一个出现坑的demo:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(_x):
def __easy_plus(__y):
return _x + __y
return __easy_plus
one_plus_function = _easy_plus(_x=1)
help(one_plus_function)
print(one_plus_function(__y=2))
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function __easy_plus in module __main__:
__easy_plus(_Dataset__y)
Traceback (most recent call last):
File "sanity_test.py", line 21, in <module>
dataset.demo()
File "sanity_test.py", line 18, in demo
print(one_plus_function(__y=1))
TypeError: __easy_plus() got an unexpected keyword argument '__y'
可以发现__easy_plus 函数的参数名称事实上被改动成为_Dataset__y ,调用one_plus_function(__y=2) 会报错,调用one_plus_function(_Dataset__y=2) 则可以正常返回结果。
但是如果将函数_easy_plus 与__easy_plus 的参数名称前缀的下划线分别去掉一个,改动成下面的形式:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(x):
def __easy_plus(_y):
return x + _y
return __easy_plus
one_plus_function = _easy_plus(x=1)
help(one_plus_function)
print(one_plus_function(_y=2))
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function __easy_plus in module __main__:
__easy_plus(_y)
3
可以发现参数名称又没有发生修改了。
通过观察规律,似乎是编译器默认参数名称包含很多下划线前缀时就会自动修改名称,于是笔者又写了下面的demo来验证这一规律:
class Dataset:
def __init__(self):
pass
def demo(self):
def _easy_plus(__x):
def __easy_plus(___y):
return __x + ___y
return __easy_plus
help(_easy_plus)
one_plus_function = _easy_plus(_Dataset__x=1)
help(one_plus_function)
dataset = Dataset()
dataset.demo()
输出结果如下:
Help on function _easy_plus in module __main__:
_easy_plus(_Dataset__x)
Help on function __easy_plus in module __main__:
__easy_plus(_Dataset___y)
的确如此,结果显示连_easy_plus 的参数名称都被修改了。
然而如果嵌套函数是写在嵌套函数之外呢?
def demo():
def _easy_plus(__x):
def __easy_plus(___y):
return __x + ___y
return __easy_plus
help(_easy_plus)
one_plus_function = _easy_plus(__x=1)
help(one_plus_function)
demo()
输出结果如下:
Help on function _easy_plus in module __main__:
_easy_plus(__x)
Help on function __easy_plus in module __main__:
__easy_plus(___y)
此时不管有多少下划线前缀也不会发生参数名称被修改的现象。
现在只能姑且总结规律为:类函数中的嵌套函数参数名称的下划线前缀超过两个就会触发参数名称修改。
虽然并不是很能理解确切的原理是什么… [Facepalm]
规律更新
其实后来发现其实不止是嵌套函数,只要是写在类中的函数,不管是静态方法还是别的对象方法,只要参数名称的下划线前缀超过两个都会触发这种机制的修改。比如下面的示例:
class Dataset:
def __init__(self):
pass
def _demo(self, __x):
return __x
dataset = Dataset()
help(dataset._demo)
输出结果为:
Help on method _demo in module __main__:
_demo(_Dataset__x) method of __main__.Dataset instance
因此规律应该是:类域内的函数参数名称的下划线前缀超过两个就会触发参数名称修改。
关于Python的类函数,因为没有明确的public 与private 的区分,因此一般默认函数名称会使用下划线作为前缀来区分是否应当被外部调用(即是否为私有的,虽然这也并不是强制的,想要调用带下划线作为名称前缀的类函数依然是可行的,比如对于list 类型的变量token_list 来说,调用token_list[0] 与token_list.__getitem__(0) 是完全等价的),但是关于函数的参数名称会被修改的确是从来没有注意到过的事情,希望我不是最后一个发现这个问题的倒霉球…[Facepalm]
问题解决
本着闲得蛋疼求真务实的精神,笔者还是去查了一下官方文档,总结下来确实跟上面猜想的一样,虽然Python类中会用下划线前缀来不严格的区分私有方法,但是用两个下划线作为前缀确实是严格的区分了私有变量。具体原理如下(其中第2点解释了原理,第345点说明了这种参数名称修改的原因):
(摘自https://docs.python.org/zh-cn/3/tutorial/classes.html#private-variables)
-
那种仅限从一个对象内部访问的私有实例变量在Python中并不存在。但是,大多数Python代码都遵循这样一个约定:带有一个下划线的名称(例如_spam )应该被当作是API的非公有部分 (无论它是函数、方法或是数据成员)。这应当被视为一个实现细节,可能不经通知即加以改变。 -
由于存在对于类私有成员的有效使用场景(例如避免名称与子类所定义的名称相冲突),因此存在对此种机制的有限支持,称为名称改写(就是这个玩意儿)。 任何形式为__spam 的标识符(至少带有两个前缀下划线,至多一个后缀下划线)的文本将被替换为_classname__spam ,其中classname 为去除了前缀下划线的当前类名称。这种改写不考虑标识符的句法位置,只要它出现在类定义内部就会进行。 -
名称改写有助于让子类重载方法而不破坏类内方法调用。例如: class Mapping:
def __init__(self, iterable):
self.items_list = []
self.__update(iterable)
def update(self, iterable):
for item in iterable:
self.items_list.append(item)
__update = update
class MappingSubclass(Mapping):
def update(self, keys, values):
for item in zip(keys, values):
self.items_list.append(item)
-
上面的示例即使在MappingSubclass 引入了一个__update 标识符的情况下也不会出错,因为它会在Mapping 类中被替换为_Mapping__update 而在MappingSubclass 类中被替换为_MappingSubclass__update 。 -
请注意,改写规则的设计主要是为了避免意外冲突;访问或修改被视为私有的变量仍然是可能的。这在特殊情况下甚至会很有用,例如在调试器中。 -
请注意传递给exec() 或eval() 的代码不会将发起调用类的类名视作当前类;这类似于global 语句的效果,因此这种效果仅限于同时经过字节码编译的代码。 同样的限制也适用于getattr() ,setattr() 和delattr() ,以及对于__dict__ 的直接引用。
后记
其实这个问题在不知情的情况下确实很难被发现,而且笔者发现按照上面的说法是不是其实__easy_plus 这个函数其实也被改名了,只是目前还没有发现它可能报错的点。总之在类中得要慎用双下划线前缀的写法了,强迫症该改还是得改。
|