[Numpy-discussion] subclassing matrix
Thu Jan 10 00:08:03 CST 2008
In the course of a project that involved heavy use of geometry and
linear algebra, I found it useful to create a Vector subclass of
numpy.matrix (represented as a column vector in my case).
I'd like to hear comments about my use of this "class promotion"
statement in __new__:
ret.__class__ = cls
It seems to me that it is hackish to just change an instance's class
on the fly, so perhaps someone could clue me in on a better practice.
Here is my reason for doing this:
Many applications of this code involve operations between instances of
numpy.matrix and instances of Vector, such as applying a linear-
operator matrix on a vector. If I omit that "class promotion"
statement, then the results of such operations cannot be instantiated
as Vector types:
>>> from vector import Vector
>>> import numpy
>>> u = Vector('1 2 3')
>>> A = numpy.matrix('2 0 0; 0 2 0; 0 0 2')
>>> p = Vector(A * u)
This is undesirable because the calculation result loses the custom
Vector methods and attributes that I want to use. However, if I use
that "class promotion" statement, the p.__class__ lookup returns what I want (Vector).
Is there a better way to achieve that?
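For what it's worth, the loss of the subclass seems to come from numpy.matrix.__new__ itself. As far as I can tell from its source, when the input is already a numpy.matrix it is handed back through data.astype(dtype), which keeps the input's own class and never uses cls. A string, list or plain ndarray input, by contrast, is built as the requested subtype. The minimal check below illustrates this. Sub is a hypothetical stand-in for Vector, and the sketch assumes a NumPy version that still ships numpy.matrix (recent versions emit a PendingDeprecationWarning for it).

```python
import numpy as np


class Sub(np.matrix):
    """Hypothetical minimal subclass whose __new__ just defers to numpy.matrix."""

    def __new__(cls, data, dtype=np.float64):
        return super().__new__(cls, data, dtype=dtype)


from_string = Sub("1 2 3")               # parsed from scratch: built as Sub
from_matrix = Sub(np.matrix("1 2 3"))    # already a matrix: returned via astype()

print(type(from_string).__name__)   # Sub
print(type(from_matrix).__name__)   # matrix
```

This is exactly the Vector(A * u) case: A * u is already a matrix, so cls is dropped.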
Here is the partial subclass code:
import numpy as _N
import math as _M
#default tolerance for equality tests
TOL_EQ = 1e-6
#default format for pretty-printing Vector instances
FMT_VECTOR_DEFAULT = "%+.5f"
class Vector(_N.matrix):
    """
    2D/3D vector class that supports numpy matrix operations and more.

        u = Vector([1,2,3])
        v = Vector('3 4 5')
        w = Vector([1, 2])
    """
    def __new__(cls, data="0. 0. 0.", dtype=_N.float64):
        """
        Subclass instance constructor.

        If data is not specified, a zero Vector is constructed.
        The constructor always returns a Vector instance.

        The instance gets a customizable Format attribute, which
        controls the printing precision.
        """
        ret = super(Vector, cls).__new__(cls, data, dtype=dtype)
        #promote the instance to cls type.
        ret.__class__ = cls
        assert ret.size in (2, 3), \
            'Vector must have either two or three components'
        #a row vector is turned into a column vector
        if ret.shape[0] == 1:
            ret = ret.T
        assert ret.shape == (ret.shape[0], 1), \
            'could not express Vector as a Mx1 matrix'
        #a 2D vector is padded with a zero third component
        if ret.shape[0] == 2:
            ret = _N.vstack((ret, 0.))
        ret.Format = FMT_VECTOR_DEFAULT
        return ret

    def __str__(self):
        fmt = getattr(self, "Format", FMT_VECTOR_DEFAULT)
        fmt = ', '.join([fmt]*3)
        return ''.join(["(", fmt, ")"]) % (self.X, self.Y, self.Z)

    def __repr__(self):
        fmt = ', '.join(['%s']*3)
        return ''.join(["%s([", fmt, "])"]) % (
            self.__class__.__name__, self.X, self.Y, self.Z)

    #### the remaining methods are Vector-specific math operations,
    #### including the X,Y,Z properties...
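A commonly recommended alternative to reassigning __class__ is NumPy's view casting: build an ordinary matrix from whatever input arrives, then call .view(cls), which returns an object of the requested subclass sharing the same data. A per-instance attribute such as Format can then be carried along by __array_finalize__, which NumPy calls whenever an instance of the subclass is created, including view casting and new-from-template cases such as slicing or transposing. The sketch below is a hypothetical rewrite of the constructor along those lines, not the code above; it keeps the same defaults and the 2D-to-3D zero padding but omits the X, Y, Z properties and the math methods.

```python
import numpy as np

# Mirrors FMT_VECTOR_DEFAULT from the code above.
FMT_VECTOR_DEFAULT = "%+.5f"


class Vector(np.matrix):
    """Column vector (3x1) built with view() instead of __class__ reassignment."""

    def __new__(cls, data="0. 0. 0.", dtype=np.float64):
        # np.matrix parses strings, lists and existing matrices alike; the
        # result may be a plain matrix, so re-label it as cls with view().
        ret = np.matrix(data, dtype=dtype).view(cls)
        if ret.shape[0] == 1:
            ret = ret.T  # a row vector becomes a column vector
        assert ret.shape[1] == 1 and ret.shape[0] in (2, 3), \
            "Vector must be a 2x1 or 3x1 column"
        if ret.shape[0] == 2:
            # pad a 2D vector with a zero third component, then view as cls again
            ret = np.vstack((ret, [[0.0]])).view(cls)
        return ret

    def __array_finalize__(self, obj):
        # Let numpy.matrix do its own bookkeeping first.
        super().__array_finalize__(obj)
        # Inherit Format from the array this one was derived from (e.g. a view
        # or transpose of a Vector), else fall back to the default.
        self.Format = getattr(obj, "Format", FMT_VECTOR_DEFAULT)


u = Vector("1 2 3")
A = np.matrix("2 0 0; 0 2 0; 0 0 2")
p = Vector(A * u)
print(type(p).__name__, p.shape)   # Vector (3, 1)
```

Because the view is taken after numpy.matrix has done the parsing, Vector(A * u) yields a Vector regardless of which class A * u itself comes back as.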