From 28c7e95a9e8a23b449580e6390495588466f5cde Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 1 Oct 2021 10:56:19 +0330 Subject: [PATCH 1/3] fix create method for non-integer pk --- orm/models.py | 10 +++++----- tests/test_columns.py | 14 +++++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/orm/models.py b/orm/models.py index aec400e..77b117e 100644 --- a/orm/models.py +++ b/orm/models.py @@ -382,7 +382,6 @@ async def first(self, **kwargs): return rows[0] async def create(self, **kwargs): - # Validate the keyword arguments. fields = self.model_cls.fields validator = typesystem.Schema( fields={key: value.validator for key, value in fields.items()} @@ -393,11 +392,12 @@ async def create(self, **kwargs): if value.validator.read_only and value.validator.has_default(): kwargs[key] = value.validator.get_default_value() - # Build the insert expression. - expr = self.table.insert() - expr = expr.values(**kwargs) + if self.model_cls.database.url.dialect == "sqlite": + expr = self.table.insert().values(**kwargs) + else: + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.insert().values(**kwargs).returning(pk_column) - # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) instance.pk = await self.database.execute(expr) return instance diff --git a/tests/test_columns.py b/tests/test_columns.py index 2df8e78..28c8113 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -64,9 +64,9 @@ async def rollback_transactions(): async def test_model_crud(): - await Product.objects.create() + product = await Product.objects.create() - product = await Product.objects.get() + product = await Product.objects.get(pk=product.pk) assert product.created.year == datetime.datetime.now().year assert product.created_day == datetime.date.today() assert product.data == {} @@ -92,6 +92,10 @@ async def test_model_crud(): assert product.price == decimal.Decimal("999.99") assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") - await User.objects.create(name="Chris") - user = await User.objects.get(name="Chris") - assert user.name == "Chris" + +@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") +async def test_model_crud_with_non_integer_pk(): + user = await User.objects.create(name="Chris") + + assert isinstance(user.pk, uuid.UUID) + assert await User.objects.get(pk=user.pk) == user From 211c47d8ac6a5999f74b88f9cef2c26432d28ccd Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 1 Oct 2021 11:21:58 +0330 Subject: [PATCH 2/3] skip for mysql --- orm/models.py | 2 +- tests/test_columns.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/orm/models.py b/orm/models.py index 77b117e..092b9c0 100644 --- a/orm/models.py +++ b/orm/models.py @@ -392,7 +392,7 @@ async def create(self, **kwargs): if value.validator.read_only and value.validator.has_default(): kwargs[key] = value.validator.get_default_value() - if self.model_cls.database.url.dialect == "sqlite": + if self.model_cls.database.url.dialect in ["mysql", "sqlite"]: expr = self.table.insert().values(**kwargs) else: pk_column = getattr(self.table.c, self.pkname) diff --git a/tests/test_columns.py b/tests/test_columns.py index 28c8113..227f689 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -93,8 +93,10 @@ async def test_model_crud(): assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") -@pytest.mark.skipif(database.url.dialect == "sqlite", reason="Not supported on SQLite") async def test_model_crud_with_non_integer_pk(): + if database.url.dialect in ["mysql", "sqlite"]: + pytest.skip("RETURNING clause not supported.") + user = await User.objects.create(name="Chris") assert isinstance(user.pk, uuid.UUID) From f4cf732d6861c07130f70258c375de149f2fcd20 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Sat, 2 Oct 2021 09:24:30 +0330 Subject: [PATCH 3/3] update create method --- orm/models.py | 12 ++++++------ tests/test_columns.py | 6 ------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/orm/models.py b/orm/models.py index 092b9c0..4c75c0a 100644 --- a/orm/models.py +++ b/orm/models.py @@ -392,14 +392,14 @@ async def create(self, **kwargs): if value.validator.read_only and value.validator.has_default(): kwargs[key] = value.validator.get_default_value() - if self.model_cls.database.url.dialect in ["mysql", "sqlite"]: - expr = self.table.insert().values(**kwargs) + instance = self.model_cls(**kwargs) + expr = self.table.insert().values(**kwargs) + + if self.pkname not in kwargs: + instance.pk = await self.database.execute(expr) else: - pk_column = getattr(self.table.c, self.pkname) - expr = self.table.insert().values(**kwargs).returning(pk_column) + await self.database.execute(expr) - instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) return instance async def delete(self) -> None: diff --git a/tests/test_columns.py b/tests/test_columns.py index 227f689..e3c14b7 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -92,12 +92,6 @@ async def test_model_crud(): assert product.price == decimal.Decimal("999.99") assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") - -async def test_model_crud_with_non_integer_pk(): - if database.url.dialect in ["mysql", "sqlite"]: - pytest.skip("RETURNING clause not supported.") - user = await User.objects.create(name="Chris") - assert isinstance(user.pk, uuid.UUID) assert await User.objects.get(pk=user.pk) == user