1 | import type { Middleware } from 'redux'
|
2 | import { getTimeMeasureUtils } from './utils'
|
3 |
|
/**
 * Callback invoked with a property key and its value that returns the value
 * to use in its place. In this file it is the type of the optional
 * `serializer` / `decycler` hooks and of the replacer handed to
 * `JSON.stringify`.
 */
type EntryProcessor = (key: string, value: any) => any
|
5 |
|
// Evaluated once at module load; selects the terse error text used by `invariant`.
const isProduction = process.env.NODE_ENV === 'production'
// Leading text of every error thrown by `invariant`.
const prefix = 'Invariant failed'

/**
 * Throws an `Error` when `condition` is falsy; does nothing otherwise.
 *
 * In production builds only the bare prefix is used as the error message.
 * In every other environment the optional `message` is appended after
 * `': '` (an empty string is appended when no message is supplied).
 *
 * @param condition - Value that is expected to be truthy.
 * @param message - Optional detail describing the violated expectation.
 * @throws Error when `condition` is falsy.
 */
function invariant(condition: any, message?: string) {
  if (!condition) {
    const text = isProduction ? prefix : `${prefix}: ${message || ''}`
    throw new Error(text)
  }
}
|
27 |
|
/**
 * Serializes `obj` to JSON using the cycle-aware replacer built by
 * `getSerialize`, so circular references are replaced instead of making
 * `JSON.stringify` throw.
 *
 * @param obj - Value to serialize.
 * @param serializer - Optional hook applied to every key/value pair.
 * @param indent - Passed through as the `space` argument of `JSON.stringify`.
 * @param decycler - Optional hook producing the replacement for a circular value.
 * @returns The result of `JSON.stringify` with the cycle-aware replacer.
 */
function stringify(
  obj: any,
  serializer?: EntryProcessor,
  indent?: string | number,
  decycler?: EntryProcessor
): string {
  const replacer = getSerialize(serializer, decycler)
  return JSON.stringify(obj, replacer, indent)
}
|
36 |
|
/**
 * Builds a `JSON.stringify` replacer that tolerates circular structures.
 *
 * The replacer keeps two parallel lists while `JSON.stringify` walks the
 * value: `ancestors` (the objects currently being serialized, root first) and
 * `trail` (the keys that lead from the root to the current holder). When a
 * value is found among its own ancestors it is handed to `decycler` instead of
 * being descended into again.
 *
 * @param serializer - Optional hook applied to every key/value pair after
 *   cycle handling; its return value is what gets serialized.
 * @param decycler - Optional hook producing the replacement for a circular
 *   value. Defaults to `'[Circular ~]'` for the root and
 *   `'[Circular ~.<dot.path>]'` for any other ancestor.
 * @returns A stateful replacer; use a fresh one for every `JSON.stringify` call.
 */
function getSerialize(
  serializer?: EntryProcessor,
  decycler?: EntryProcessor
): EntryProcessor {
  const ancestors: any[] = []
  const trail: any[] = []

  const resolveCycle: EntryProcessor =
    decycler ||
    function (_: string, value: any) {
      if (ancestors[0] === value) {
        return '[Circular ~]'
      }
      const pathToAncestor = trail.slice(0, ancestors.indexOf(value)).join('.')
      return '[Circular ~.' + pathToAncestor + ']'
    }

  return function (this: any, key: string, value: any) {
    if (ancestors.length === 0) {
      // First invocation: `value` is the root being serialized.
      ancestors.push(value)
    } else {
      const holderIndex = ancestors.indexOf(this)
      if (holderIndex === -1) {
        // Descended one level: the holder is a newly entered object.
        ancestors.push(this)
        trail.push(key)
      } else {
        // Back at (or still in) a known holder: drop everything below it.
        ancestors.splice(holderIndex + 1)
        trail.splice(holderIndex, Infinity, key)
      }

      if (ancestors.indexOf(value) !== -1) {
        value = resolveCycle.call(this, key, value)
      }
    }

    return serializer == null ? value : serializer.call(this, key, value)
  }
}
|
63 |
|
64 |
|
65 |
|
66 |
|
67 |
|
68 |
|
/**
 * Default immutability predicate.
 *
 * Anything whose `typeof` is not `'object'` (primitives, `undefined`,
 * functions) is reported as immutable, as are `null` and frozen objects.
 * Only non-frozen objects (including arrays) are reported as mutable.
 *
 * @param value - Value to classify.
 * @returns `true` when the value is treated as immutable.
 */
export function isImmutableDefault(value: unknown): boolean {
  if (value === null || typeof value !== 'object') {
    return true
  }
  return Object.isFrozen(value)
}
|
77 |
|
/**
 * Takes a snapshot of `obj` and returns a handle that can later report whether
 * `obj` has been mutated since the snapshot was taken.
 *
 * @param isImmutable - Predicate marking values that are not descended into.
 * @param ignorePaths - Dot-separated paths excluded from tracking, if any.
 * @param obj - Root value to track.
 * @returns An object whose `detectMutations()` compares the current contents
 *   of `obj` against the snapshot.
 */
export function trackForMutations(
  isImmutable: IsImmutableFunc,
  ignorePaths: string[] | undefined,
  obj: any
) {
  const snapshot = trackProperties(isImmutable, ignorePaths, obj)
  const compareWithSnapshot = () =>
    detectMutations(isImmutable, ignorePaths, snapshot, obj)

  return { detectMutations: compareWithSnapshot }
}
|
90 |
|
/**
 * One node of the snapshot tree built by `trackProperties`.
 */
interface TrackedProperty {
  /** The value (by reference for objects) seen when the snapshot was taken. */
  value: any
  /**
   * Snapshot nodes for the value's enumerable keys, keyed by property name.
   * NOTE: `trackProperties` only assigns this when the value is not reported
   * immutable, so at runtime it can be absent despite the non-optional type.
   */
  children: Record<string, any>
}
|
95 |
|
/**
 * Recursively records `obj` and, for values not reported immutable, every
 * enumerable key reachable from it, producing the snapshot tree later walked
 * by `detectMutations`.
 *
 * @param isImmutable - Predicate; values it accepts become leaf nodes.
 * @param ignorePaths - Dot-separated paths whose subtrees are not recorded.
 * @param obj - Value to snapshot.
 * @param path - Dot-separated path of `obj` from the root (`''` for the root).
 * @returns A node holding `obj` and, for mutable values, its child nodes.
 */
function trackProperties(
  isImmutable: IsImmutableFunc,
  ignorePaths: IgnorePaths = [],
  obj: Record<string, any>,
  path: string = ''
) {
  if (isImmutable(obj)) {
    // Leaf: immutable values carry no `children` entry at all.
    const leaf: Partial<TrackedProperty> = { value: obj }
    return leaf as TrackedProperty
  }

  const children: Record<string, any> = {}
  for (const key in obj) {
    const childPath = path ? `${path}.${key}` : key
    const isIgnored = ignorePaths.indexOf(childPath) >= 0
    if (!isIgnored) {
      children[key] = trackProperties(
        isImmutable,
        ignorePaths,
        obj[key],
        childPath
      )
    }
  }

  const node: TrackedProperty = { value: obj, children }
  return node
}
|
123 |
|
/**
 * List of dot-separated property paths (built as `parent.child`) that
 * `trackProperties` and `detectMutations` skip.
 */
type IgnorePaths = readonly string[]
|
125 |
|
/**
 * Recursively compares a snapshot node with the current value and reports the
 * first in-place mutation found.
 *
 * A mutation is reported when a value differs from its snapshot while its
 * parent is still the very same reference as in the snapshot: the parent was
 * changed in place instead of being replaced. Subtrees whose parent reference
 * changed are still walked, because a deeper object may have been kept by
 * reference and mutated.
 *
 * @param isImmutable - Predicate; the walk stops at values it accepts.
 * @param ignorePaths - Dot-separated paths that are skipped.
 * @param trackedProperty - Snapshot node for `obj`; may be `undefined` at
 *   runtime for keys that did not exist when the snapshot was taken.
 * @param obj - Current value at this position.
 * @param sameParentRef - Whether the parent is reference-equal to its snapshot.
 * @param path - Dot-separated path of `obj` from the root (`''` for the root).
 * @returns `{ wasMutated: true, path }` for the first mutation found,
 *   otherwise `{ wasMutated: false }`.
 */
function detectMutations(
  isImmutable: IsImmutableFunc,
  ignorePaths: IgnorePaths = [],
  trackedProperty: TrackedProperty,
  obj: any,
  sameParentRef: boolean = false,
  path: string = ''
): { wasMutated: boolean; path?: string } {
  const prevObj = trackedProperty ? trackedProperty.value : undefined

  const sameRef = prevObj === obj

  // `NaN !== NaN`, so an untouched NaN must not be flagged. Only suppress the
  // report when BOTH sides are NaN; a change *to* NaN is still a mutation.
  const unchangedNaN = Number.isNaN(prevObj) && Number.isNaN(obj)

  if (sameParentRef && !sameRef && !unchangedNaN) {
    return { wasMutated: true, path }
  }

  if (isImmutable(prevObj) || isImmutable(obj)) {
    return { wasMutated: false }
  }

  // The snapshot node (or its children) can be missing for keys added after
  // the snapshot when a custom `isImmutable` rejects `undefined`; fall back to
  // an empty set of children instead of throwing.
  const prevChildren: Record<string, any> =
    (trackedProperty && trackedProperty.children) || {}

  // Union of keys present in the snapshot and in the current value, so both
  // removed and added keys are visited.
  const keysToDetect: Record<string, boolean> = {}
  for (const key in prevChildren) {
    keysToDetect[key] = true
  }
  for (const key in obj) {
    keysToDetect[key] = true
  }

  for (const key in keysToDetect) {
    const childPath = path ? path + '.' + key : key
    if (ignorePaths.length && ignorePaths.indexOf(childPath) !== -1) {
      continue
    }

    const result = detectMutations(
      isImmutable,
      ignorePaths,
      prevChildren[key],
      obj[key],
      sameRef,
      childPath
    )

    if (result.wasMutated) {
      return result
    }
  }
  return { wasMutated: false }
}
|
176 |
|
/**
 * Predicate deciding whether a value is treated as immutable. Values it
 * accepts are recorded as leaves by `trackProperties` and are not descended
 * into by `detectMutations`.
 */
type IsImmutableFunc = (value: any) => boolean
|
178 |
|
179 |
|
180 |
|
181 |
|
182 |
|
183 |
|
/**
 * Options accepted by `createImmutableStateInvariantMiddleware`.
 */
export interface ImmutableStateInvariantMiddlewareOptions {
  /**
   * Predicate deciding whether a value is treated as immutable; such values
   * are not descended into while tracking or comparing state.
   * Defaults to `isImmutableDefault`.
   */
  isImmutable?: IsImmutableFunc

  /**
   * Dot-separated paths from the root state (e.g. `a.b`) that are skipped
   * when tracking and comparing state.
   */
  ignoredPaths?: string[]

  /**
   * Threshold forwarded to `getTimeMeasureUtils`; defaults to `32`.
   * NOTE(review): presumably a duration in milliseconds after which a
   * slow-check warning is emitted — confirm against `./utils`.
   */
  warnAfter?: number

  /**
   * Fallback used only when `ignoredPaths` is falsy (`ignoredPaths || ignore`).
   * NOTE(review): looks like an older name for `ignoredPaths` — confirm
   * before relying on it in new code.
   */
  ignore?: string[]
}
|
203 |
|
204 |
|
205 |
|
206 |
|
207 |
|
208 |
|
209 |
|
210 |
|
211 |
|
212 |
|
/**
 * Creates a Redux middleware that throws (via `invariant`) when the store's
 * state is mutated in place, either between two dispatches or while an action
 * is being handled.
 *
 * When `process.env.NODE_ENV === 'production'` a pass-through middleware is
 * returned and no checks are performed.
 *
 * @param options - See `ImmutableStateInvariantMiddlewareOptions`;
 *   `ignore` is used only when `ignoredPaths` is falsy.
 * @returns The configured middleware.
 */
export function createImmutableStateInvariantMiddleware(
  options: ImmutableStateInvariantMiddlewareOptions = {}
): Middleware {
  if (process.env.NODE_ENV === 'production') {
    return () => (next) => (action) => next(action)
  }

  const { isImmutable = isImmutableDefault, warnAfter = 32, ignore } = options
  const ignoredPaths = options.ignoredPaths || ignore

  const track = (value: any) =>
    trackForMutations(isImmutable, ignoredPaths, value)

  return ({ getState }) => {
    let state = getState()
    let tracker = track(state)

    // Compares the current snapshot against the tracked object, then starts
    // tracking the latest state. The snapshot is refreshed BEFORE any error is
    // raised by the caller, so a later dispatch does not re-report the same
    // mutation.
    const diffAndRetrack = () => {
      state = getState()
      const outcome = tracker.detectMutations()
      tracker = track(state)
      return outcome
    }

    return (next) => (action) => {
      const measureUtils = getTimeMeasureUtils(
        warnAfter,
        'ImmutableStateInvariantMiddleware'
      )

      // Check 1: anything that changed since the previous dispatch finished.
      measureUtils.measureTime(() => {
        const outcome = diffAndRetrack()
        const mutatedPath = outcome.path || ''

        invariant(
          !outcome.wasMutated,
          `A state mutation was detected between dispatches, in the path '${mutatedPath}'. This may cause incorrect behavior. (https://redux.js.org/style-guide/style-guide#do-not-mutate-state)`
        )
      })

      const dispatchedAction = next(action)

      // Check 2: anything mutated while this action was being handled.
      measureUtils.measureTime(() => {
        const outcome = diffAndRetrack()

        // Build the message (which serializes the action) only on failure.
        if (outcome.wasMutated) {
          invariant(
            !outcome.wasMutated,
            `A state mutation was detected inside a dispatch, in the path: ${
              outcome.path || ''
            }. Take a look at the reducer(s) handling the action ${stringify(
              action
            )}. (https://redux.js.org/style-guide/style-guide#do-not-mutate-state)`
          )
        }
      })

      measureUtils.warnIfExceeded()

      return dispatchedAction
    }
  }
}
|