Automatically use a function at insert or select

Sometimes the application needs to apply a function on every insert or select. For example, the application might need geometries in lat/lon coordinates while they are stored in a projected coordinate system in the DB. To avoid having to tweak every query with an ST_Transform() call, it is possible to define a TypeDecorator that applies the function automatically:

 11 import re
 12 from typing import Any
 13
 14 from sqlalchemy import Column
 15 from sqlalchemy import Integer
 16 from sqlalchemy import MetaData
 17 from sqlalchemy import func
 18 from sqlalchemy import text
 19 from sqlalchemy.orm import declarative_base
 20 from sqlalchemy.types import TypeDecorator
 21
 22 from geoalchemy2 import Geometry
 23 from geoalchemy2 import shape
 24
 25 # Tests imports
 26 from tests import test_only_with_dialects
 27
# Metadata collection handed to the declarative base below; the tests in this file call
# ``create_all()`` / ``drop_all()`` on it to manage the schema.
metadata = MetaData()

# Declarative base class bound to ``metadata``; the mapped ``Point`` model derives from it.
Base = declarative_base(metadata=metadata)
 31
 32
 33 class TransformedGeometry(TypeDecorator):
 34     """This class is used to insert a ST_Transform() in each insert or select."""
 35
 36     impl = Geometry
 37
 38     cache_ok = True
 39
 40     def __init__(self, db_srid, app_srid, **kwargs):
 41         kwargs["srid"] = db_srid
 42         super().__init__(**kwargs)
 43         self.app_srid = app_srid
 44         self.db_srid = db_srid
 45
 46     def column_expression(self, col):
 47         """The column_expression() method is overridden to set the correct type.
 48
 49         This is needed so that the returned element will also be decorated. In this case we don't
 50         want to transform it again afterwards so we set the same SRID to both the ``db_srid`` and
 51         ``app_srid`` arguments.
 52         Without this the SRID of the WKBElement would be wrong.
 53         """
 54         return getattr(func, self.impl.as_binary)(
 55             func.ST_Transform(col, self.app_srid),
 56             type_=self.__class__(db_srid=self.app_srid, app_srid=self.app_srid),
 57         )
 58
 59     def bind_expression(self, bindvalue):
 60         return func.ST_Transform(
 61             self.impl.bind_expression(bindvalue),
 62             self.db_srid,
 63             type_=self,
 64         )
 65
 66
 67 class ThreeDGeometry(TypeDecorator):
 68     """This class is used to insert a ST_Force3D() in each insert."""
 69
 70     impl = Geometry
 71
 72     cache_ok = True
 73
 74     def column_expression(self, col):
 75         """The column_expression() method is overridden to set the correct type.
 76
 77         This is not needed in this example but it is needed if one wants to override other methods
 78         of the TypeDecorator class, like ``process_result_value()`` for example.
 79         """
 80         return getattr(func, self.impl.as_binary)(col, type_=self)
 81
 82     def bind_expression(self, bindvalue):
 83         return func.ST_Force3D(
 84             self.impl.bind_expression(bindvalue),
 85             type=self,
 86         )
 87
 88
 89 class Point(Base):  # type: ignore
 90     __tablename__ = "point"
 91     id = Column(Integer, primary_key=True)
 92     raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
 93     geom: Column[Any] = Column(
 94         TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT")
 95     )
 96     three_d_geom: Column = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 97
 98
 99 def check_wkb(wkb, x, y):
100     pt = shape.to_shape(wkb)
101     assert round(pt.x, 5) == x
102     assert round(pt.y, 5) == y
103
104
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Tests for the ``TransformedGeometry`` and ``ThreeDGeometry`` columns of ``Point``."""

    def _create_one_point(self, session, conn):
        """Recreate the schema, insert a single ``Point`` and return its ``id``."""
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        # The same EWKT value is assigned to all three geometry columns.
        ewkt = "SRID=4326;POINT(5 45)"
        point = Point()
        point.raw_geom = ewkt
        point.geom = ewkt
        point.three_d_geom = ewkt  # 2D geometry going into the 3D column

        session.add(point)
        session.flush()
        # Expire the instance so its attributes are loaded again on next access.
        session.expire(point)

        return point.id

    def test_transform(self, session, conn):
        """Check the values of ``geom`` seen through the ORM and through raw SQL."""
        self._create_one_point(session, conn)

        # ORM query: both geometry columns come back with SRID 4326.
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # Raw SQL query: inspect what is actually stored in the ``geom`` column.
        raw_query = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        stored_row = session.execute(raw_query).fetchone()
        assert stored_row.id == 1
        stored_pattern = r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354[0-9]*\)"
        assert re.match(stored_pattern, stored_row.geom)

        # Compare geom and raw_geom with the automatic and an explicit transform.
        explicit_transform = func.ST_Transform(Point.raw_geom, 2154).label("trans")
        entity, raw_geom, transformed = session.query(
            Point,
            Point.raw_geom,
            explicit_transform,
        ).one()

        assert entity.id == 1

        assert entity.geom.srid == 4326
        check_wkb(entity.geom, 5, 45)

        assert entity.raw_geom.srid == 4326
        check_wkb(entity.raw_geom, 5, 45)

        assert raw_geom.srid == 4326
        check_wkb(raw_geom, 5, 45)

        assert transformed.srid == 2154
        check_wkb(transformed, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        """Check the value read back from ``three_d_geom`` after a 2D insert."""
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326

        expected_hex = "01010000a0e6100000000000000000144000000000008046400000000000000000"
        assert loaded.three_d_geom.desc.lower() == (expected_hex)

Gallery generated by Sphinx-Gallery