"""
Basic units
===========

A small, self-contained unit system (`BasicUnit` and `TaggedValue`) together
with a `matplotlib.units.ConversionInterface` implementation that is
registered at the bottom of this module for both classes.
"""

import itertools
import math

from packaging.version import parse as parse_version

import numpy as np

import matplotlib.ticker as ticker
import matplotlib.units as units


class ProxyDelegate:
    """Descriptor that builds a new proxy object on every attribute access."""

    def __init__(self, fn_name, proxy_type):
        """
        Parameters
        ----------
        fn_name : str
            Name of the method the proxy will forward to.
        proxy_type : type
            Proxy class, instantiated as ``proxy_type(fn_name, obj)``.
        """
        self.proxy_type = proxy_type
        self.fn_name = fn_name

    def __get__(self, obj, objtype=None):
        # The proxy is bound to *obj* because its constructor reads
        # ``obj.proxy_target`` (and ``obj.unit`` for the converting proxies).
        return self.proxy_type(self.fn_name, obj)


class TaggedValueMeta(type):
    """Metaclass that installs a `ProxyDelegate` for each ``_proxies`` entry."""

    def __init__(self, name, bases, dict):
        # Only names that ``hasattr`` cannot already resolve on the class get
        # a delegate; anything found (own or inherited) is left untouched.
        for fn_name, proxy_type in self._proxies.items():
            if not hasattr(self, fn_name):
                setattr(self, fn_name, ProxyDelegate(fn_name, proxy_type))


class PassThroughProxy:
    """Forward a call unchanged to the same-named method of the wrapped value."""

    def __init__(self, fn_name, obj):
        self.fn_name = fn_name
        self.target = obj.proxy_target

    def __call__(self, *args):
        target_fn = getattr(self.target, self.fn_name)
        return target_fn(*args)


class ConvertArgsProxy(PassThroughProxy):
    """Convert every argument to the owner's unit, then forward raw values."""

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        tagged_args = []
        for arg in args:
            try:
                tagged_args.append(arg.convert_to(self.unit))
            except AttributeError:
                # No ``convert_to``: treat the argument as a bare value that
                # is already expressed in our unit.
                tagged_args.append(TaggedValue(arg, self.unit))
        raw_args = tuple(tagged.get_value() for tagged in tagged_args)
        return super().__call__(*raw_args)


class ConvertReturnProxy(PassThroughProxy):
    """Forward the call and tag the result with the owner's unit."""

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        ret = super().__call__(*args)
        if ret is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, self.unit)


class ConvertAllProxy(PassThroughProxy):
    """
    Convert the arguments, forward the call, and tag the result.

    The unit of the result is decided by the module-level ``unit_resolver``
    from the operation name and the units of all operands.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        raw_args = []
        # The owner's unit always comes first, followed by one entry per
        # argument (None for arguments that carry no unit).
        arg_units = [self.unit]
        for arg in args:
            if hasattr(arg, 'get_unit') and not hasattr(arg, 'convert_to'):
                # If this argument has a unit type but no conversion ability,
                # this operation is prohibited.
                return NotImplemented

            if hasattr(arg, 'convert_to'):
                try:
                    arg = arg.convert_to(self.unit)
                except Exception:
                    # Deliberately best-effort: when the conversion fails the
                    # argument keeps its own unit and the resolver below
                    # decides whether the mix of units is acceptable.
                    pass
                arg_units.append(arg.get_unit())
                raw_args.append(arg.get_value())
            else:
                raw_args.append(arg)
                if hasattr(arg, 'get_unit'):
                    arg_units.append(arg.get_unit())
                else:
                    arg_units.append(None)

        ret = super().__call__(*raw_args)
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)


class TaggedValue(metaclass=TaggedValueMeta):
    """A value tagged with the unit it is expressed in."""

    # Special methods to route to the wrapped value through proxy objects.
    # The metaclass installs a delegate only for names that ``hasattr`` does
    # not already find on the class.
    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    def __new__(cls, value, unit):
        # Try to generate a new subclass that also derives from the type of
        # *value*; fall back to a plain instance of *cls* when either the
        # subclass creation or ``object.__new__`` raises TypeError.
        value_class = type(value)
        try:
            subcls = type(f'TaggedValue_of_{value_class.__name__}',
                          (cls, value_class), {})
            return object.__new__(subcls)
        except TypeError:
            return object.__new__(cls)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        # Object the proxies forward their calls to.
        self.proxy_target = self.value

    def __copy__(self):
        return TaggedValue(self.value, self.unit)

    def __getattribute__(self, name):
        # Dunder names are always resolved on the TaggedValue itself.
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        # Any other attribute is taken from the wrapped value when it has
        # one, unless the name is defined directly in this instance's class.
        variable = object.__getattribute__(self, 'value')
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object, copy=False):
        # *copy* is accepted but not used by this implementation.
        return np.asarray(self.value, dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return f'TaggedValue({self.value!r}, {self.unit!r})'

    def __str__(self):
        return f"{self.value} in {self.unit}"

    def __len__(self):
        return len(self.value)

    # Item access is only defined when running on NumPy 1.20 or newer.
    if parse_version(np.__version__) >= parse_version('1.20'):
        def __getitem__(self, key):
            return TaggedValue(self.value[key], self.unit)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return a copy holding only the entries that *mask* leaves unmasked."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """
        Return this value expressed in *unit*.

        *self* is returned unchanged when *unit* equals the current unit or
        is falsy. If the current unit has no ``convert_value_to`` method
        (AttributeError), *self* is wrapped as-is in a new `TaggedValue`
        tagged with *unit*; no numeric conversion happens in that case.
        """
        if unit == self.unit or not unit:
            return self
        try:
            new_value = self.unit.convert_value_to(self.value, unit)
        except AttributeError:
            new_value = self
        return TaggedValue(new_value, unit)

    def get_value(self):
        """Return the wrapped value."""
        return self.value

    def get_unit(self):
        """Return the unit this value is tagged with."""
        return self.unit


class BasicUnit:
    """A named unit that stores conversion functions to other units."""

    # numpy scalars convert eager and np.float64(2) * BasicUnit('cm')
    # would thus return a numpy scalar. To avoid this, we increase the
    # priority of the BasicUnit.
    __array_priority__ = np.float64(0).__array_priority__ + 1

    def __init__(self, name, fullname=None):
        """
        Parameters
        ----------
        name : str
            Short name, used by ``repr``.
        fullname : str, optional
            Long name, used by ``str``; defaults to *name*.
        """
        self.name = name
        if fullname is None:
            fullname = name
        self.fullname = fullname
        # Maps a target unit to a callable converting values into that unit.
        self.conversions = dict()

    def __repr__(self):
        return f'BasicUnit({self.name})'

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        """Tag *value* with this unit."""
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        value = rhs
        unit = self
        if hasattr(rhs, 'get_unit'):
            # The right operand carries a unit of its own: let the resolver
            # decide what unit the product has.
            value = rhs.get_value()
            unit = rhs.get_unit()
            unit = unit_resolver('__mul__', (self, unit))
        if unit is NotImplemented:
            return NotImplemented
        return TaggedValue(value, unit)

    def __rmul__(self, lhs):
        return self*lhs

    def __array_wrap__(self, array, context=None, return_scalar=False):
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None, copy=False):
        # As an array, a unit is the scalar 1 (cast to *t* when given).
        ret = np.array(1)
        if t is not None:
            return ret.astype(t)
        else:
            return ret

    def add_conversion_factor(self, unit, factor):
        """Register conversion to *unit* as multiplication by *factor*."""
        def convert(x):
            return x*factor
        self.conversions[unit] = convert

    def add_conversion_fn(self, unit, fn):
        """Register *fn* as the conversion function to *unit*."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        """Return the function registered for *unit* (KeyError if there is none)."""
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Return *value* converted to *unit* (KeyError if no conversion exists)."""
        conversion_fn = self.conversions[unit]
        return conversion_fn(value)

    def get_unit(self):
        return self


class UnitResolver:
    """Decide the unit of an operation's result from its operands' units."""

    def addition_rule(self, units):
        """
        Return the common unit, or NotImplemented if the units differ.

        An empty *units* yields None instead of raising IndexError.
        """
        units = list(units)
        if not units:
            return None
        for unit_1, unit_2 in itertools.pairwise(units):
            if unit_1 != unit_2:
                return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        """
        Return the single truthy unit among *units*.

        More than one truthy unit gives NotImplemented. When no operand
        carries a unit the result is None (unit-less) instead of raising
        IndexError.
        """
        non_null = [u for u in units if u]
        if len(non_null) > 1:
            return NotImplemented
        return non_null[0] if non_null else None

    # The rules are stored as plain functions, so __call__ passes ``self``
    # explicitly when invoking them.
    op_dict = {
        '__mul__': multiplication_rule,
        '__rmul__': multiplication_rule,
        '__add__': addition_rule,
        '__radd__': addition_rule,
        '__sub__': addition_rule,
        '__rsub__': addition_rule}

    def __call__(self, operation, units):
        """Return the result unit, or NotImplemented for unknown operations."""
        if operation not in self.op_dict:
            return NotImplemented

        return self.op_dict[operation](self, units)


unit_resolver = UnitResolver()

cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1/2.54)

radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')
radians.add_conversion_factor(degrees, 180.0/np.pi)
degrees.add_conversion_factor(radians, np.pi/180.0)

secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

secs.add_conversion_fn(hertz, lambda x: 1./x)
secs.add_conversion_factor(minutes, 1/60.0)


# radians formatting
def rad_fn(x, pos=None):
    """
    Format *x* as a multiple of pi/2, for use as a tick formatter.

    *x* is turned into an integer count ``n`` of half-pi steps
    (``x / pi * 2``, nudged by 0.25 away from zero and truncated).
    *pos* is accepted for the formatter call signature and is unused.
    """
    if x >= 0:
        n = int((x / np.pi) * 2.0 + 0.25)
    else:
        n = int((x / np.pi) * 2.0 - 0.25)

    if n == 0:
        return '0'
    elif n == 1:
        return r'$\pi/2$'
    elif n == 2:
        return r'$\pi$'
    elif n == -1:
        return r'$-\pi/2$'
    elif n == -2:
        return r'$-\pi$'
    elif n % 2 == 0:
        return fr'${n//2}\pi$'
    else:
        return fr'${n}\pi/2$'


class BasicUnitConverter(units.ConversionInterface):
    """Conversion interface registered for `BasicUnit` and `TaggedValue`."""

    @staticmethod
    def axisinfo(unit, axis):
        """Return an AxisInfo instance for *unit*, or None if *unit* is None."""

        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        elif unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        elif unit is not None:
            if hasattr(unit, 'fullname'):
                return units.AxisInfo(label=unit.fullname)
            elif hasattr(unit, 'unit'):
                return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def _plain_value(val, unit):
        """Return a single element as a bare value expressed in *unit*."""
        if np.ma.is_masked(val):
            return np.nan
        try:
            return val.convert_to(unit).get_value()
        except AttributeError:
            # No ``convert_to``/``get_value``: hand the element back as-is.
            return val

    @staticmethod
    def convert(val, unit, axis):
        """
        Convert *val* to bare numbers expressed in *unit*.

        Iterables give a float array with masked entries replaced by NaN;
        a scalar gives a single value. Elements without unit support are
        passed through unchanged in both cases.
        """
        if np.iterable(val):
            if isinstance(val, np.ma.MaskedArray):
                val = val.astype(float).filled(np.nan)
            out = np.empty(len(val))
            for i, thisval in enumerate(val):
                out[i] = BasicUnitConverter._plain_value(thisval, unit)
            return out
        return BasicUnitConverter._plain_value(val, unit)

    @staticmethod
    def default_units(x, axis):
        """Return the default unit for x or None."""
        if np.iterable(x):
            # The first element decides; unit-less elements give None.
            for thisx in x:
                return getattr(thisx, 'unit', None)
        # Scalars and empty iterables: use the object's own unit, if any.
        return getattr(x, 'unit', None)


def cos(x):
    """
    Cosine of a unit-tagged angle, or a list of cosines for an iterable.

    Each value is converted to radians before `math.cos` is applied.
    """
    if np.iterable(x):
        return [math.cos(val.convert_to(radians).get_value()) for val in x]
    else:
        return math.cos(x.convert_to(radians).get_value())


units.registry[BasicUnit] = units.registry[TaggedValue] = BasicUnitConverter()