1 | import { Context } from './context'
|
2 | import { ExclusiveKeys, MaybePromise } from './core/helpers/util'
|
3 | import { MiddlewareFn } from './middleware'
|
4 | import d from 'debug'
|
/** Logger for this module, created under the `telegraf:session` namespace. */
const debug = d('telegraf:session')
|
6 |
|
/**
 * Session storage whose operations complete synchronously.
 *
 * @typeParam T type of the session value stored under each key.
 */
export interface SyncSessionStore<T> {
  /** Returns the value saved under `name`, or `undefined` when there is none. */
  get: (name: string) => T | undefined
  /** Saves `value` under `name`. */
  set: (name: string, value: T) => void
  /** Removes whatever is saved under `name`. */
  delete: (name: string) => void
}

/**
 * Session storage whose operations return promises.
 *
 * @typeParam T type of the session value stored under each key.
 */
export interface AsyncSessionStore<T> {
  /** Resolves to the value saved under `name`, or `undefined` when there is none. */
  get: (name: string) => Promise<T | undefined>
  /** Saves `value` under `name`; the resolved value is not used by this module. */
  set: (name: string, value: T) => Promise<unknown>
  /** Removes whatever is saved under `name`; the resolved value is not used by this module. */
  delete: (name: string) => Promise<unknown>
}

/** Any store accepted by the session middleware, synchronous or asynchronous. */
export type SessionStore<T> =
  | SyncSessionStore<T>
  | AsyncSessionStore<T>
|
20 |
|
/**
 * Options accepted by the `session` middleware factory.
 *
 * @typeParam S session value type.
 * @typeParam C context type the middleware runs on.
 * @typeParam P name of the context property that holds the session.
 */
interface SessionOptions<S, C extends Context, P extends string> {
  /** Context property that exposes the session; `'session'` when omitted. */
  property?: P
  /**
   * Derives the session key for an update (may be async). A falsy key means the
   * update gets no session. When omitted, the key is `"<from.id>:<chat.id>"`.
   */
  getSessionKey?: (ctx: C) => MaybePromise<string | undefined>
  /** Backing store; a fresh in-memory `MemorySessionStore` when omitted. */
  store?: SessionStore<S>
  /** Builds the session value to use when the store returns a nullish value for the key. */
  defaultSession?: (ctx: C) => S
}

/** A context type carrying an optional `session` property of type `S`. */
export interface SessionContext<S extends object> extends Context {
  session?: S
}
|
33 |
|
34 |
|
35 |
|
36 |
|
37 |
|
38 |
|
39 |
|
40 |
|
41 |
|
42 |
|
43 |
|
44 |
|
45 |
|
46 |
|
47 |
|
/**
 * Returns middleware that exposes per-key session state as `ctx[property]`
 * (`ctx.session` unless `options.property` names another property).
 *
 * - The key comes from `options.getSessionKey`, by default
 *   `"<ctx.from.id>:<ctx.chat.id>"`. A falsy key sets `ctx[property]` to
 *   `undefined` and runs the rest of the chain without a session.
 * - The value is read from `options.store` (a new `MemorySessionStore` by
 *   default). If the store yields a nullish value, `options.defaultSession(ctx)`
 *   is used instead.
 * - Concurrent updates with the same key share one pending `store.get` and one
 *   in-memory session object. That object is reference-counted and removed from
 *   the cache when the last of those updates finishes.
 * - After the downstream middleware settles (including when it throws), the
 *   store is written only if `ctx[property]` was read or assigned. A nullish
 *   value is removed with `store.delete`; anything else is saved with `store.set`.
 * - If `store.get` rejects, the error propagates to the caller and the failed
 *   read is discarded, so the next update for that key queries the store again.
 *
 * @param options optional property name, key function, store and default-session factory.
 * @returns the session middleware.
 */
export function session<
  S extends NonNullable<C[P]>,
  C extends Context & { [key in P]?: C[P] },
  P extends (ExclusiveKeys<C, Context> & string) | 'session' = 'session',
  // ^ presumably restricts custom property names to keys of C that are not
  //   keys of the base Context (see ExclusiveKeys) — TODO confirm.
>(options?: SessionOptions<S, C, P>): MiddlewareFn<C> {
  const prop = options?.property ?? ('session' as P)
  const getSessionKey = options?.getSessionKey ?? defaultGetSessionKey
  const store = options?.store ?? new MemorySessionStore()

  // Session objects shared by the updates currently in flight for a key.
  // `counter` is the number of those updates; the entry is dropped at 0.
  const cache = new Map<string, { ref?: S; counter: number }>()

  // Pending store reads, so concurrent cache misses for one key issue a single
  // upstream `store.get`.
  const concurrents = new Map<string, MaybePromise<S | undefined>>()

  // Ordering in this function matters: the cache/counter bookkeeping relies on
  // which statements run before and after each `await`.
  return async (ctx, next) => {
    const updId = ctx.update.update_id

    // getSessionKey may be async, so requests can still race past this point.
    // The cache re-check after the store read below reconciles that.
    const key = await getSessionKey(ctx)
    if (!key) {
      // Make the property present (as undefined) for later middleware.
      ctx[prop] = undefined as unknown as S
      return await next()
    }

    let cached = cache.get(key)
    if (cached) {
      debug(`(${updId}) found cached session, reusing from cache`)
      ++cached.counter
    } else {
      debug(`(${updId}) did not find cached session`)

      // Reuse another request's pending read for this key, if one exists.
      let promise = concurrents.get(key)
      if (promise) {
        debug(`(${updId}) found a concurrent request, reusing promise`)
      } else {
        debug(`(${updId}) fetching from upstream store`)
        promise = store.get(key)
      }

      // A synchronous store returns a plain value instead of a promise; it is
      // recorded the same way.
      concurrents.set(key, promise)
      try {
        const upstream = await promise
        // Every concurrent waiter already holds `promise`, so it can be dropped now.
        concurrents.delete(key)
        debug(`(${updId}) updating cache`)

        // Another request for this key may have filled the cache while we awaited.
        const c = cache.get(key)
        if (c) {
          c.counter++
          // Keep the in-memory reference: it is at least as new as the store value.
          cached = c
        } else {
          // First to arrive: cache the store value (or the default session).
          cached = { ref: upstream ?? options?.defaultSession?.(ctx), counter: 1 }
          cache.set(key, cached)
        }
      } catch (err) {
        // The read rejected (or defaultSession threw). A rejected promise must not
        // stay in `concurrents`. Otherwise every later request for this key would
        // reuse it and fail without ever retrying the store. The identity check
        // avoids removing a newer read started by another request.
        if (concurrents.get(key) === promise) concurrents.delete(key)
        throw err
      }
    }

    // `cached` is always assigned here; a const binding lets TypeScript keep that
    // narrowing inside the accessor closures below.
    const c = cached

    // Becomes true once downstream code reads or assigns the session property.
    let touched = false

    Object.defineProperty(ctx, prop, {
      get() {
        touched = true
        return c.ref
      },
      set(value: S) {
        touched = true
        c.ref = value
      },
    })

    try {
      await next()
    } finally {
      if (--c.counter === 0) {
        // Last in-flight update for this key: release the shared entry.
        debug(`(${updId}) refcounter reached 0, removing cached`)
        cache.delete(key)
      }
      debug(`(${updId}) middlewares completed, checking session`)

      // Persist only when the session property was accessed.
      if (touched) {
        if (ctx[prop] == null) {
          debug(`(${updId}) ctx.${prop} missing, removing from store`)
          await store.delete(key)
        } else {
          debug(`(${updId}) ctx.${prop} found, updating store`)
          await store.set(key, ctx[prop] as S)
        }
      }
    }
  }
}
|
152 |
|
/**
 * Default session key: `"<ctx.from.id>:<ctx.chat.id>"`.
 *
 * Resolves to `undefined` when either id is `null`/`undefined`, including when
 * `ctx.from` or `ctx.chat` itself is missing. An id of `0` is still accepted.
 */
async function defaultGetSessionKey(ctx: Context): Promise<string | undefined> {
  const userId = ctx.from?.id
  const chatId = ctx.chat?.id
  const bothKnown = userId != null && chatId != null
  return bothKnown ? `${userId}:${chatId}` : undefined
}
|
161 |
|
162 |
|
/**
 * A `SyncSessionStore` that keeps sessions in a process-local `Map`.
 *
 * Each `set` stamps the entry with an absolute expiry of `Date.now() + ttl`.
 * There is no background sweep: an expired entry is discarded only when `get` is
 * called for its key. Expired entries that are never read again therefore stay
 * in memory.
 */
export class MemorySessionStore<T> implements SyncSessionStore<T> {
  private readonly store = new Map<string, { session: T; expires: number }>()

  /**
   * @param ttl entry lifetime in milliseconds, added to `Date.now()` on each
   *   `set`. The default `Infinity` means entries never expire.
   */
  constructor(private readonly ttl = Infinity) {}

  /** Returns the live session for `name`; an absent or expired entry yields `undefined` (expired ones are deleted). */
  get(name: string): T | undefined {
    const entry = this.store.get(name)
    if (!entry) return undefined

    const isExpired = entry.expires < Date.now()
    if (isExpired) {
      this.delete(name)
      return undefined
    }
    return entry.session
  }

  /** Stores `value` under `name`, resetting its expiry to `ttl` ms from now. */
  set(name: string, value: T): void {
    const expires = Date.now() + this.ttl
    this.store.set(name, { session: value, expires })
  }

  /** Removes the entry for `name`, if any. */
  delete(name: string): void {
    this.store.delete(name)
  }
}
|
188 |
|
189 |
|
/**
 * Type guard that reports whether `ctx` has a `session` property, own or
 * inherited (checked with the `in` operator), and narrows it to
 * `SessionContext<S>`.
 *
 * Only presence is checked. A `session` property holding `undefined` still
 * passes, and the value is not validated against `S`.
 */
export function isSessionContext<S extends object>(
  ctx: Context,
): ctx is SessionContext<S> {
  const hasSessionProp = 'session' in ctx
  return hasSessionProp
}
|