diff --git a/peewee_plus.py b/peewee_plus.py index 5f2ba9d..fd7b522 100644 --- a/peewee_plus.py +++ b/peewee_plus.py @@ -11,7 +11,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/" __authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"] -__all__ = ["PathField"] +__all__ = ["PathField", "PrecisionFloatField"] class PathField(peewee.CharField): @@ -72,3 +72,33 @@ class PathField(peewee.CharField): if self.relative_to else Path(super().python_value(value)) ) + + +class PrecisionFloatField(peewee.FloatField): + """Field class for storing floats with custom precision parameters + + This field adds support for specifying the ``M`` and ``D`` precision parameters of a + ``FLOAT`` field as specified in the `MySQL documentation`_. + accepts. See the `MySQL docs`_ for more information. + + .. warning:: This field implements syntax that is specific to MySQL. When used with a + different database backend, such as SQLite or Postgres, it behaves identically + to :class:`peewee.FloatField` + + .. note:: This field's implementation was adapted from here_ + + .. _`MySQL documentation`: https://dev.mysql.com/doc/refman/8.0/en/floating-point-types.html + .. _here: https://stackoverflow.com/a/67476045/5361209 + + :param max_digits: Maximum number of digits, combined from left and right of the decimal place, + to store for the value. + :param decimal_places: Maximum number of digits that will be stored after the decimal place + """ + + def __init__(self, *args, max_digits: int = 10, decimal_places: int = 4, **kwargs): + super().__init__(*args, **kwargs) + self.max_digits = max_digits + self.decimal_places = decimal_places + + def get_modifiers(self): + return [self.max_digits, self.decimal_places] diff --git a/tests/test_precision_float_field.py b/tests/test_precision_float_field.py new file mode 100644 index 0000000..d5bea9a --- /dev/null +++ b/tests/test_precision_float_field.py @@ -0,0 +1,31 @@ +# pylint: disable=unused-import +# pylint: disable=redefined-outer-name +# pylint: disable=missing-class-docstring +# pylint: disable=too-few-public-methods +import peewee + +import peewee_plus +from .fixtures import fakedb + + +# There isn't anything we can really test here since this field implements +# a MySQL-specific syntax and we test with SQLite. This test is here just +# to ensure that the behavior is consistent with the normal FloatField when +# working with an unsupported database backend +def test_compatibility(fakedb): + """Check that the precision float field works on sqlite""" + + class TestModel(peewee.Model): + class Meta: + database = fakedb + + precise = peewee_plus.PrecisionFloatField(max_digits=7, decimal_places=3) + imprecise = peewee.FloatField() + + fakedb.create_tables([TestModel]) + + model = TestModel(precise=1234.567, imprecise=1234.567) + model.save() + + model = TestModel.get() + assert model.precise == model.imprecise