Note
Click here to download the full example code
Automatically use a function at insert or select
Sometimes the application needs to apply a function on insert or on select.
For example, the application might need the geometry in lat/lon coordinates while the
geometries are stored in a projected coordinate system in the DB. To avoid having to tweak
every query with a ST_Transform() call, it is possible to define a TypeDecorator that
applies the function automatically.
11 import re
12 from typing import Any
13
14 from sqlalchemy import Column
15 from sqlalchemy import Integer
16 from sqlalchemy import MetaData
17 from sqlalchemy import func
18 from sqlalchemy import text
19 from sqlalchemy.orm import declarative_base
20 from sqlalchemy.types import TypeDecorator
21
22 from geoalchemy2 import Geometry
23 from geoalchemy2 import shape
24
25 # Tests imports
26 from tests import test_only_with_dialects
27
# Shared MetaData: every table declared below is registered here, which lets the
# tests drop and recreate the whole schema with a single call.
metadata = MetaData()

# Declarative base whose models register their tables on ``metadata``.
Base = declarative_base(metadata=metadata)
31
32
class TransformedGeometry(TypeDecorator):
    """Geometry type that wraps values in ST_Transform() on insert and on select.

    Values are stored in the database in ``db_srid`` and handed back to the
    application in ``app_srid``; the reprojection happens in SQL.
    """

    impl = Geometry

    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        # The underlying Geometry column is declared with the database SRID.
        super().__init__(**{**kwargs, "srid": db_srid})
        self.db_srid = db_srid
        self.app_srid = app_srid

    def column_expression(self, col):
        """Reproject the selected column to ``app_srid`` and keep the result decorated.

        The returned expression is typed with a new ``TransformedGeometry`` whose
        ``db_srid`` and ``app_srid`` are both ``app_srid``: the value has already
        been reprojected, so it must not be transformed a second time, and the
        resulting WKBElement must carry the application SRID.
        """
        to_binary = getattr(func, self.impl.as_binary)
        reprojected = func.ST_Transform(col, self.app_srid)
        result_type = type(self)(db_srid=self.app_srid, app_srid=self.app_srid)
        return to_binary(reprojected, type_=result_type)

    def bind_expression(self, bindvalue):
        """Reproject an incoming value to ``db_srid`` before it is written."""
        geometry_param = self.impl.bind_expression(bindvalue)
        return func.ST_Transform(geometry_param, self.db_srid, type_=self)
65
66
class ThreeDGeometry(TypeDecorator):
    """Geometry type that wraps every inserted value in ST_Force3D()."""

    impl = Geometry

    cache_ok = True

    def column_expression(self, col):
        """The column_expression() method is overridden to set the correct type.

        This is not needed in this example but it is needed if one wants to override other methods
        of the TypeDecorator class, like ``process_result_value()`` for example.
        """
        return getattr(func, self.impl.as_binary)(col, type_=self)

    def bind_expression(self, bindvalue):
        """Wrap the bound value in ST_Force3D() so that 2D input is stored as 3D."""
        return func.ST_Force3D(
            self.impl.bind_expression(bindvalue),
            # SQLAlchemy function constructors take the result type as ``type_``;
            # a ``type`` keyword is not the result-type parameter and would not
            # attach this decorator to the expression.
            type_=self,
        )
87
88
class Point(Base):  # type: ignore
    """Model storing the same point through three differently typed geometry columns."""

    __tablename__ = "point"
    id = Column(Integer, primary_key=True)
    # Plain geometry column in SRID 4326, with no automatic function applied.
    raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
    # Stored in SRID 2154 in the database, exposed to the application in SRID 4326
    # through ST_Transform() calls added by TransformedGeometry.
    geom: Column[Any] = Column(
        TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT")
    )
    # 3D column; inserted values go through ST_Force3D() (see ThreeDGeometry).
    three_d_geom: Column = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
97
98
def check_wkb(wkb, x, y):
    """Assert that ``wkb`` decodes to a point whose coordinates, rounded to 5 decimals, are (x, y)."""
    decoded = shape.to_shape(wkb)
    for actual, expected in ((decoded.x, x), (decoded.y, y)):
        assert round(actual, 5) == expected
103
104
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Round-trip checks for the TransformedGeometry and ThreeDGeometry decorators."""

    def _create_one_point(self, session, conn):
        """Recreate the schema, insert a single ``Point`` and return its primary key."""
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        ewkt = "SRID=4326;POINT(5 45)"
        new_point = Point()
        new_point.raw_geom = ewkt
        new_point.geom = ewkt
        # A 2D geometry goes into the 3D column; ThreeDGeometry wraps it in ST_Force3D().
        new_point.three_d_geom = ewkt

        session.add(new_point)
        session.flush()
        # Expire the instance so later attribute access reloads values from the database.
        session.expire(new_point)

        return new_point.id

    def test_transform(self, session, conn):
        self._create_one_point(session, conn)

        # ORM load: the decorated column comes back in the application SRID.
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # A raw text query does not go through the decorator, exposing the stored 2154 value.
        row = session.execute(text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")).fetchone()
        assert row.id == 1
        assert re.match(
            r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354[0-9]*\)", row.geom
        )

        # Automatic transform on ``geom`` versus an explicit ST_Transform on ``raw_geom``.
        point, raw_geom, transformed = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert point.id == 1

        assert point.geom.srid == 4326
        check_wkb(point.geom, 5, 45)

        assert point.raw_geom.srid == 4326
        check_wkb(point.raw_geom, 5, 45)

        assert raw_geom.srid == 4326
        check_wkb(raw_geom, 5, 45)

        assert transformed.srid == 2154
        check_wkb(transformed, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()

        assert loaded.id == 1
        three_d = loaded.three_d_geom
        assert three_d.srid == 4326
        # Hex EWKB of POINT Z (5 45 0) with SRID 4326: the missing Z coordinate became 0.
        expected_hex = "01010000a0e6100000000000000000144000000000008046400000000000000000"
        assert three_d.desc.lower() == (expected_hex)