Python面向对象编程08-运算符重载

运算符重载

运算符重载的概念

在 Python 中,可以对两个列表使用加法运算符 + ,例如:

>>> [1, 2] + [5, 6] [1, 2, 5, 6]

结果显示,加法运算符可以拼接两个序列。还可以对序列和整数使用乘法运算符 * 例如:

>>> (2, 3) * 3 (2, 3, 2, 3, 2, 3)

结果显示,序列和整数的“乘法运算”表示重复序列。

这种设计虽然方便了序列的操作,但是看起来完全背离了运算符的数学本义:为什么对两个序列使用加法运算符的结果不是序列内的元素对应相加?为什么对序列和整数使用乘法运算符的结果不是序列内的元素和整数相乘?实际上,Python 有一个第三方库 NumPy 就是这样设计的:NumPy 中的两个数组相加就是将数组的元素相加,数组和数值相乘就是将元素和数值相乘:

>>> import numpy as np >>> np.array([1, 2]) + np.array([5, 6]) array([6, 8]) >>> np.array([2, 3, 4]) * 3 array([ 6, 9, 12])

前几节介绍了 property__getattr__ 等特殊方法,既然它们都可以彻底改变对象的属性查找行为,那么显然也有类似的特殊方法决定了一个对象面对运算符的行为。这就是运算符重载的概念,Python 可以对特定的对象实现运算符重载,使用运算符来表示对象的一些操作方式,从而使对象操作更加优雅。

对于 Python 用的不多的数据类型 set ,其支持的运算符操作更多。例如,按位与运算符 & 表示对两个集合取交集(同时属于两个集合的元素):

>>> {1, 2, 4, 5} & {2, 3, 5, 7} {2, 5}

或者使用按位或 | 运算符表示两个集合的并集:

>>> {1, 2, 4, 5} | {2, 3, 5, 7} {1, 2, 3, 4, 5, 7}

除此之外,减法运算符 - 表示差集,异或运算符 ^ 表示对称差集。这种表达方式既便于书写,也便于阅读。

有了运算符重载的概念后,接下来介绍运算符重载的方法。

如何重载运算符

下面以最简单的加法运算符 + 来说明运算符是如何完成重载的。

Python 的运算符实质上是通过调用对象的特殊方法来实现的。对于加法运算符来说,对应的重载方法是 .__add__()

接下来先介绍的都是二元运算符。在这种情况下,运算符操作的都是两个对象:左操作数和右操作数。因此,对应运算符的特殊方法也需要拥有两个参数: 一个是 self ,代表用于引出运算的左操作数;另一个参数常记为 other ,代表用于参与运算的右操作数。因此,该方法的完整参数列表为 .__add__(self, other) ,最后只需要在该方法内返回一个合适的值,该值便作为运算符运算后的结果。

这里需要注意,.__add__() 重载方法会被运算左侧操作数调用,也就是说需要由左操作数实现这个方法。因此参数 self 就代表左操作数实例本身,而右操作数仅仅是作为第二个参数传入,即运算 x + y 的实质是调用 x.__add__(y)

下表列出了所有的二元运算符及其重载方法:

运算符 对应方法 本义 说明
+ __add__() 加法
- __sub__() 减法
* __mul__() 乘法
@ __matmul__() 矩阵相乘 Python3.5 开始支持
/ __truediv__() 除法
// __floordiv__() 整数相除
% __mod__() 取余
divmod() __divmod__() 除法并取余 由函数支持实现,而非运算符
**pow() __pow__() 乘方 由于 pow() 函数支持第三个参数,该方法需要支持第三个参数 modulo 。详见pow() 函数的描述
<< __lshift__() 左移位
>> __rshift__() 右移位
& __and__() 按位与
| __or__() 按位或
^ __xor__() 按位异或

例如,以下实现了一个针对三维向量运算的 Vector3D 类,针对向量加法、减法和叉乘实现了运算符重载,并使用 __str__ 特殊方法让结果更清晰:

class Vector3D:
    def __add__(self, other):
        return Vector3D(self.x + other.x, self.y + other.y, self.z + other.z)
    def __sub__(self, other):
        return Vector3D(self.x - other.x, self.y - other.y, self.z - other.z)
    def __matmul__(self, other):
        return Vector3D(self.y * other.z - self.z * other.y,
                        self.z * other.x - self.x * other.z,
                        self.x * other.y - self.y * other.x)
    def __str__(self):
        return f'Vector({self.x}, {self.y}, {self.z})'
    __repr__ = __str__

这样就可以直接用运算符操作向量实例了:

>>> v01 = Vector3D(3, 5, 1) >>> v02 = Vector3D(2, 4, 3) >>> v01 + v02 Vector(5, 9, 4) >>> v01 - v02 Vector(1, 1, -2) >>> v01 @ v02 Vector(11, -7, 2)

重载运算符时,它的运算对象可能是不同类型的数据。例如向量的乘法,和它相乘的对象可能是一个数值,也可能是另一个向量。这时就需要在方法内使用 isinstance() 函数判断另一个对象的类型,并使用不同的计算方式:

from numbers import Real
   
class Vector3D:
    def __mul__(self, other):
        if isinstance(other, Real):
            return Vector3D(self.x * other, self.y * other, self.z * other)
        elif isinstance(other, Vector3D):
            return self.x * other.x + self.y * other.y + self.z * other.z

这样,在应用 * 运算符时,Vector3D 可以根据右侧操作数的类型采用不同的计算方式:向量乘向量则返回点积,向量乘以一个数则返回标量积:

>>> v01 * v02 29 >>> v01 * 2 Vector(6, 10, 2)

Python 还对运算符重载过程中的一个细节作了规定:如果一个对象不知道该如何处理右侧操作数,那么它应该返回一个特殊的对象 NotImplemented 。Python 解释器在接收到这个特殊的对象后,就会自动对这次运算结果抛出 TypeError 异常,并自动生成合适的异常信息。因此,以上实现的 .__mul__() 最后还需要加上这么一句:

        else:
            return NotImplemented

这样程序就能处理不合理的乘法操作了:

>>> v01 * [2, 3, 4] Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: can't multiply sequence by non-int of type 'Vector3D'

注意区分 NotImplemented 对象和 NotImplementedError 异常,后者用在函数或方法中,提示使用者该函数并未实现,或者方法必须要被重写。

如果一个重载的运算符可以参与对许多类型的运算,那么就需要写很多 elif 条件。这种情况下,可以使用 functools 模块中的单分派函数 singledispatchmethod(因为运算符重载只有一个额外参数,非常适合单分派的应用场合)。

使用单分派函数时,在方法内部只需要简单返回一个 NotImplemented ,作为所有类型匹配失败之后的 fallback 机制:

from functools import singledispatchmethod
   
class Vector3D:
    @singledispatchmethod
    def __mul__(self, other):
        return NotImplemented

然后在附近找个地方注册一下不同类型对应的分派函数:

from collections.abc import Sequence

@Vector3D.__mul__.register
def __dispatch(self, other: Real):
    return Vector3D(self.x * other, self.y * other, self.z * other)
@Vector3D.__mul__.register
def __dispatch(self, other: Vector3D):
    return self.x * other.x + self.y * other.y + self.z * other.z
@Vector3D.__mul__.register
def __dispatch(self, other: Sequence):
    return self.x * other[0] + self.y * other[1] + self.z * other[2]

现在这个向量类不仅支持和序列相乘(将序列看作一个向量),而且以后如果还需要继续扩展支持的类型也会方便得多。

反向重载运算符

在有了运算符重载的知识以后,就可以明白为什么同样是乘以一个整数,Python 内置的列表和一些第三方库的数组结果完全不同,原因是列表和数值实现了不同的 .__mul__() 方法,它们分别实现了不同的处理逻辑。

但是还有一个问题没有解决:之前实现的重载运算符调用的都是由左操作数实现相应的方法,这就意味着交换两个操作数的位置,运算符可能会不满足交换律,甚至直接出现错误:

>>> 2 * v01 Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: unsupported operand type(s) for *: 'int' and 'Vector3D'

为了使运算符满足交换律,每针对两种不同类型重载一次运算符,就需要在左右操作数两侧类型都补充针对这种操作符的处理逻辑。但很显然,这种方法不仅费时费力,而且效率低下,尤其是对内置类型不好处理。

这就涉及到 Python 运算符的反向重载问题了:当一个对象通过运算符与右侧操作数运算时,如果发现该对象没有重载对应的运算符,那么会尝试调用右侧操作数对象的反向重载运算符方法。下表列出了所有运算符对应的反向重载运算符方法:

运算符 对应反向重载方法 运算符 对应反向重载方法
+ __radd__() divmod() __rdivmod__()
- __rsub__() **pow() __rpow__()
* __rmul__() << __rlshift__()
@ __rmatmul__() >> __rrshift__()
/ __rtruediv__() & __rand__()
// __rfloordiv__() ^ __rxor__()
% __rmod__() | __ror__()

可以看出,运算符对应的反向重载方法就是相比起普通重载方法的名称多了一个 r(代表 reflected ,也可以理解为 reversed 或 right )。

假设进行运算 x + y(或者任意的二元运算符),首先会尝试调用运算方法 x.__add__(y) ;如果发现 x 没有实现方法 .__add__() 或者该方法返回了 NotImplemented ,那么会继续尝试调用 y 的反向重载方法 .__radd__() ,计算 y.__radd__(x) 。如果 y 类型的 .__radd__() 也不存在或返回了 NotImplemented ,那就确实不支持两者的运算了。

所以,重载运算符返回 NotImplemented 的意义就在于触发反向重载机制。如果直接抛出 TypeError 异常,就不会触发反向重载机制了。

一般来说,满足运算交换律的类可以简单地将反向重载方法作为重载方法的一个引用:

class Vector3D:    
    __radd__ = __add__
    __rmul__ = __mul__

现在任意顺序的运算都能正常处理了:

>>> 2 * v01 Vector(6, 10, 2)

Python 解释器在处理反向运算符重载时还有一些比较有意思的细节:

  • 传入了三个参数的 pow() 函数不会试图调用反向重载的 .__rpow__() 方法(官方文档说这样处理逻辑太复杂了)
  • 显式调用运算对象的重载方法会因为方法不存在而引起 AttributeError ,或者在不支持运算时直接得到变量 NotImplemented ,也不会发生任何反向计算过程
  • 如果右侧运算对象的类型是左侧运算对象的子类,并且子类实现了反向运算符重载运算方法,那么右侧运算对象的反向重载方法会优先于左侧的重载方法调用

最后一点的意义是,继承的子类可以通过反向重载运算符覆盖父类的运算方法。例如,假设要让 Python 列表的相加就表示对应元素相加而不是列表拼接,那么只需要继承内置的 list 并重写两个加法方法,这样在任何顺序下都以子类的重载方法为准:

from operator import add
from itertools import starmap

class ArithmeticList(list):
    def __add__(self, other: Sequence):
        return ArithmeticList(starmap(add, zip(self, other)))
    __radd__ = __add__

效果为:

>>> [3, 4, 6] + ArithmeticList([1, 5, 2]) [4, 9, 8] >>> ArithmeticList([1, 5, 2]) + [3, 4, 6] [4, 9, 8]

赋值运算符的重载

不管是运算符方法还是反向运算符方法,它们都不应该修改调用者自身(即 self )的属性,而是应该返回一个新的对象。如果要修改自身,程序应该明确使用赋值运算符,例如使用 += 代替 +

尽管从理论上来说,二元赋值运算符如 x += y 的实质就是 x = x + y ,两者的区别在于 += 可以只修改自身而不创建新的对象,因此 Python 也支持重载这些二元赋值运算符。

下表列出了所有的二元赋值运算符及其重载方法:

运算符 对应重载方法 运算符 对应重载方法
+= __iadd__() **= __ipow__()
-= __isub__() <<= __ilshift__()
*= __imul__() >>= __irshift__()
@= __imatmul__() &= __iand__()
/= __itruediv__() ^= __ixor__()
//= __ifloordiv__() |= __ior__()
%= __imod__()

这些二元赋值运算符对应的重载方法实际上就是在对应的二元运算符名称前加上了一个 i

一般来说,二元赋值运算 x += y 会尝试调用对应的重载方法 x.__iadd__(y) ,如果该类没有实现赋值运算符的重载方法,或者该重载方法返回了 NotImplemented 那么会尝试将二元赋值表达式展开成 x = x + y ,对 x + y 尝试依次调用 x.__add__(y)y.__radd__(x) 。(注意:由于一个小 bug ,二元赋值运算符 **= 对应的方法 .__ipow__() 如果返回了 NotImplemented ,不会试图调用 .__pow__().__rpow__() ,这个小 bug 直到 Python3.10 才修复)

注意,虽然上文说的是赋值运算符应该在方法内修改自身,但是因为有些对象是不可变类型,所以 .__iadd__() 等方法还需要提供返回值,这个返回值才是左操作符最终的赋值结果。例如,以下是对 Vector3D 重载赋值运算的示例:

class Vector3D:    
    def __iadd__(self, other):
        self.x += other.x
        self.y += other.y
        self.z += other.z
        return self
    def __imul__(self, other: Real):
        self.x *= other
        self.y *= other
        self.z *= other
        return self

效果为:

>>> v01 *= 2 >>> v01 Vector(2, 2, -4)

这一过程实际上可能存在一些奇怪的现象,Python 的官方文档也着重说明了这个问题。考虑以下元组:

>>> tup = ([1, 2], 3)

元组是不可变类型,但列表是可变类型,这为赋值运算符的问题埋下了伏笔。尝试对元组内的列表应用赋值运算符:

>>> tup[0] += [4] Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: 'tuple' object does not support item assignment

明明是对列表应用的赋值运算符,但元组却产生了错误。不仅如此,再次查看元组会发现即便产生了错误,列表也确实被修改了:

>>> tup ([1, 2, 4], 3)

这个现象的原因是 tup[0] += [4] 实际调用的应该是 tup[0].__iadd__([4])tup[0] 是一个列表,它使用 .__iadd__() 方法相当于使用 .extend() 扩充另一个列表。问题就在于调用完成 .__iadd__() 后,会试图将返回值赋值给 tup[0] 元素,而元组元素是不能直接赋值的,这就产生了错误。如果不是使用赋值运算符,而是直接调用 .entend() 方法就不会出现错误了:

>>> tup[0].extend([5, 6]) >>> tup ([1, 2, 4, 5, 6], 3)

一元运算符的重载

除了二元运算符外,一元运算符也可以重载。可以重载的一元运算符和它们的重载方法如下表所示:

运算符 对应方法 本义 说明
- __neg__() 取负
+ __pos__() 取正 由于该方法可以被重载,所以有些对象可能会发生 x != +x 的情况
abs() __abs__() 取绝对值 由函数而不是运算符实现
~ __invert__() 按位非(按位取反)
round(ndigits) __round__() 舍入 由函数而不是运算符实现
math.floor() __floor__() 向上取整 由标准库中的函数而不是运算符实现
math.ceil() __ceil__() 向下取整 由标准库中的函数而不是运算符实现

__int__ 这类被 int() 等构造函数调用,实现向对应类型转换的特殊方法,其实也可以归为这一类,不过它们均在前面的章节介绍过了。

一元运算符相应的特殊方法只有一个参数 self ,并且一元运算符重载时,也不应该修改自身,而要创建并返回适合类型的新实例。以下是对 Vector3D 进一步支持一元运算符的相关代码:

class Vector3D:    
    def __abs__(self):
        return sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)
    def __pos__(self):
        return Vector3D(self.x, self.y, self.z)
    def __neg__(self):
        return Vector3D(-self.x, -self.y, -self.z)

以及使用示例:

>>> -v01 Vector(-2, -2, 4) >>> abs(v01) 4.898979485566356

关系运算符的重载

除了算术和赋值,二元关系运算符也支持重载。下表列出了二元关系运算符及其对应的重载方法:

运算符 对应重载方法 含义
== __eq__() 相等
!= __ne__() 不相等
>= __ge__() 大于等于
<= __le__() 小于等于
> __gt__() 大于
< __lt__() 小于

注意:Python 不允许重载 is 运算符,也不允许重载 notandor 这类布尔运算符。is 只用于判断两个变量指向的是否为同一个对象,而对象之间的相等判断可以通过重载 == 运算符实现。如果要改变布尔运算符的结果,可以通过 __bool__() 方法改变对象的布尔行为,或者使用 ~&| 这三个运算符代替。

这些关系运算符没有所谓的“赋值”版本,甚至也没有所谓的“反向”版本;不过,__eq__()__ne__() 其实就是自身的反向版本,而 __lt__()__gt__()__le__()__ge__() 是彼此的反向版本:当 x.__eq__(y)x.__lt__(y) 不存在或返回了 NotImplemented ,就会尝试从右操作数调用 y.__eq__(x)y.__gt__(x)

基类 object 实现了 __eq__() 方法,不过它的判断逻辑非常简单:如果两者使用 is 测试相等,就返回 True ,否则返回 NotImplementedobject 也实现了 __ne__() 方法,它的逻辑就是对 __eq__() 的结果取反。

可以重载 Vector3D 类的相等比较运算符,实现向量相等的比较逻辑:

class Vector3D:    
    def __eq__(self, other):
        return self.x == other.x and self.y == other.y and self.z == other.z
>>> Vector3D(1.0, 2.0, 3.0) == Vector3D(1, 2, 3) True >>> Vector3D(1.01, 2.0, 3.0) == Vector3D(1, 2, 3) False

一般情况下不需要重载 != 运算符,让它自己取反就行了。同时重载 ==!= 运算符可能会使两种判断结果都返回 True ,违背了数学直觉。

更复杂的情况发生在关系运算符上:Python3 并没有在基类上实现关系运算符的处理(其实是有这些方法,只不过它们直接返回 NotImplemented ),也不会将 >= 运算符视为 >== 的组合。为了使两个对象能够参与比较,至少要在类中实现相等的判断,以及不带等号的比较和带等号的比较各一个;如果两个对象属于不同的类,可能要将这四个比较运算符全部实现。

但其实,如果返回结果是符合数学直觉的,那么一个类只要实现 == 运算符以及四个比较运算符的其中一个,其它结果是可以根据它们推导出来的。例如,实现了大于运算符 > 后,剩下运算符的处理方式为:

class Comparable:
    def __ge__(self, other):
        return self > other or self == other
    def __lt__(self, other):
        return not self >= other
    def __le__(self, other):
        return not self > other

但是,一个类究竟应该实现这四个方法中的哪一个?不同的人可能会有不同的习惯,但不知道一个类实现了哪个方法,其余方法的实现也就无从下手。为此,Python3.2 在 functools 标准库中引入了一个特殊的装饰类的装饰器 total_ordering ,只要类实现了 == 以及四个关系运算符中的一个,它就可以自动添加剩余的关系运算符方法:

from functools import total_ordering

@total_ordering
class StudentScore:
    def __init__(self, math, physics, chemistry):
        self.math = math
        self.physics = physics
        self.chemistry = chemistry
    def __eq__(self, other):
        if self._has_score(other):
            return (self.math == other.math and self.physics == other
                    and self.chemistry == other.chemistry)
        else:
            return NotImplemented
    def __gt__(self, other):
        if self._has_score(other):
            return ((self.math, self.physics, self.chemistry) >
                    (other.math, other.physics, other.chemistry))
        else:

这样,StudentScore 类就可以支持与其它类型对象的比较行为了:

>>> bob = StudentScore(94, 82, 87) >>> alice = StudentScore(89, 92, 97) >>> tom = StudentScore(89, 80, 86) >>> bob > alice True >>> alice < tom False >>> class Colledge: ... >>> entrance_line = Colledge() >>> entrance_line.math = 70 >>> entrance_line.physics = 70 >>> entrance_line.chemistry = 70 >>> alice > entrance_line True >>> entrance_line <= tom True

参考资料/延伸阅读

Python3 DataModel 官方文档对运算符重载的介绍

Python3 标准库 functools 官方文档

京ICP备2021034974号
contact me by hello@frozencandles.fun