mirror of
				https://github.com/enpaul/peewee-plus.git
				synced 2025-11-04 01:08:38 +00:00 
			
		
		
		
	Add enum field for storing enum references in the database
This commit is contained in:
		@@ -1,8 +1,10 @@
 | 
				
			|||||||
 | 
					import enum
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
from typing import Dict
 | 
					from typing import Dict
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from typing import Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import peewee
 | 
					import peewee
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -14,7 +16,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/"
 | 
				
			|||||||
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
 | 
					__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = ["PathField", "PrecisionFloatField", "JSONField"]
 | 
					__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PathField(peewee.CharField):
 | 
					class PathField(peewee.CharField):
 | 
				
			||||||
@@ -162,3 +164,55 @@ class JSONField(peewee.TextField):
 | 
				
			|||||||
            raise peewee.IntegrityError(
 | 
					            raise peewee.IntegrityError(
 | 
				
			||||||
                f"Failed to decode JSON value from database column '{self.column}'"
 | 
					                f"Failed to decode JSON value from database column '{self.column}'"
 | 
				
			||||||
            ) from err
 | 
					            ) from err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EnumField(peewee.CharField):
 | 
				
			||||||
 | 
					    """Field class for storing Enums
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This field can be used for storing members of an :class:`enum.Enum` in the database,
 | 
				
			||||||
 | 
					    effectively storing a database reference to a value defined in the application.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    .. warning:: This field ties database data to application structure: if the Enum passed
 | 
				
			||||||
 | 
					                 to this field is modified then the application may encounter errors when
 | 
				
			||||||
 | 
					                 trying to interface with the database schema.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        >>> class MyOptions(enum.Enum):
 | 
				
			||||||
 | 
					        ...    FOO = "have you ever heard the tragedy"
 | 
				
			||||||
 | 
					        ...    BAR = "of darth plageius"
 | 
				
			||||||
 | 
					        ...    BAZ = "the wise?"
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					        >>>
 | 
				
			||||||
 | 
					        >>> class MyModel(peewee.Model):
 | 
				
			||||||
 | 
					        ...    option = EnumField(MyOptions)
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					        >>> m = MyModel(option=MyOptions.FOO)
 | 
				
			||||||
 | 
					        >>> m.save()
 | 
				
			||||||
 | 
					        >>> m.option
 | 
				
			||||||
 | 
					        <MyOptions.FOO: "have you ever heard the tragedy">
 | 
				
			||||||
 | 
					        >>>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :param enumeration: The Enum to accept members of and to use for decoding database values
 | 
				
			||||||
 | 
					    :raises TypeError: If the value to be written to the field is not a member of the
 | 
				
			||||||
 | 
					                       specified Enum
 | 
				
			||||||
 | 
					    :raises peewee.IntegrityError: If the value read back from the database cannot be decoded to
 | 
				
			||||||
 | 
					                                   a member of the specified Enum
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.enumeration = enumeration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def db_value(self, value: enum.Enum) -> str:
 | 
				
			||||||
 | 
					        if not isinstance(value, self.enumeration):
 | 
				
			||||||
 | 
					            raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'")
 | 
				
			||||||
 | 
					        return super().db_value(value.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def python_value(self, value: str) -> enum.Enum:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return self.enumeration[super().python_value(value)]
 | 
				
			||||||
 | 
					        except KeyError:
 | 
				
			||||||
 | 
					            raise peewee.IntegrityError(
 | 
				
			||||||
 | 
					                f"Enum {self.enumeration.__name__} has no value with name '{value}'"
 | 
				
			||||||
 | 
					            ) from None
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										53
									
								
								tests/test_enumfield.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								tests/test_enumfield.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
				
			|||||||
 | 
					# pylint: disable=redefined-outer-name
 | 
				
			||||||
 | 
					# pylint: disable=missing-class-docstring
 | 
				
			||||||
 | 
					# pylint: disable=too-few-public-methods
 | 
				
			||||||
 | 
					# pylint: disable=unused-import
 | 
				
			||||||
 | 
					import enum
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import peewee
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import peewee_plus
 | 
				
			||||||
 | 
					from .fixtures import fakedb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_enum(fakedb):
 | 
				
			||||||
 | 
					    """Test basic functionality of the enum field"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class TestEnum(enum.Enum):
 | 
				
			||||||
 | 
					        FOO = "fizz"
 | 
				
			||||||
 | 
					        BAR = "buzz"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class TestModel(peewee.Model):
 | 
				
			||||||
 | 
					        class Meta:
 | 
				
			||||||
 | 
					            database = fakedb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        data = peewee_plus.EnumField(TestEnum)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fakedb.create_tables([TestModel])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = TestModel(data=TestEnum.FOO)
 | 
				
			||||||
 | 
					    model.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = TestModel.get()
 | 
				
			||||||
 | 
					    assert model.data == TestEnum.FOO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ModifiedEnum(enum.Enum):
 | 
				
			||||||
 | 
					        BAR = "buzz"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ModifiedModel(peewee.Model):
 | 
				
			||||||
 | 
					        class Meta:
 | 
				
			||||||
 | 
					            table_name = TestModel._meta.table_name  # pylint: disable=protected-access
 | 
				
			||||||
 | 
					            database = fakedb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        data = peewee_plus.EnumField(ModifiedEnum)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(peewee.IntegrityError):
 | 
				
			||||||
 | 
					        ModifiedModel.get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class BadEnum(enum.Enum):
 | 
				
			||||||
 | 
					        NOTHING = "nowhere"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(TypeError):
 | 
				
			||||||
 | 
					        bad = TestModel(data=BadEnum.NOTHING)
 | 
				
			||||||
 | 
					        bad.save()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user