/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */
/*
 * NOTE(review): this copy of the file is damaged and will not compile as it stands.
 *
 * 1. Line breaks were collapsed, leaving the file on three physical lines. The two
 *    end-of-line comments (the one after the file-level Suppress annotation and the one
 *    after the readTimeout call in connect) now run to the end of their physical line
 *    and hide everything after them, including the package declaration, the imports and
 *    the class header.
 * 2. Text enclosed in angle brackets was dropped. Both map properties have lost their
 *    type arguments, and connect breaks off inside the protocols loop: the text resumes
 *    with the tail of a cookie lookup (presumably getCookie, which connect calls). The
 *    rest of connect and anything declared between the two is absent.
 *
 * Restore the file from version control before changing code. The comments added in this
 * pass describe only what the remaining text shows; the long lines are left untouched
 * because re-flowing them would change which text is code and which is comment.
 *
 * Members that are still readable at the end of the file:
 * - addListener and removeListeners do nothing and return Unit.
 * - The companion object exposes NAME (taken from NativeWebSocketModuleSpec.NAME) and
 *   setCustomClientBuilder, which stores an optional CustomClientBuilder shared by all
 *   instances; applyCustomBuilder applies it to an OkHttpClient.Builder when one is set.
 * - getDefaultOrigin parses its argument as a URI, maps the wss scheme to https and ws to
 *   http, keeps http and https, and uses an empty scheme for anything else. The port is
 *   appended only when the URI carries one. A URISyntaxException is rethrown as
 *   IllegalArgumentException.
 * - The cookie lookup tail returns the first value stored under the Cookie key, or null
 *   when there is none, and rethrows URISyntaxException and IOException as
 *   IllegalArgumentException.
 */
@file:Suppress("DEPRECATION_ERROR") // Conflicting okhttp versions package com.facebook.react.modules.websocket import com.facebook.common.logging.FLog import com.facebook.fbreact.specs.NativeWebSocketModuleSpec import com.facebook.react.bridge.Arguments import com.facebook.react.bridge.ReactApplicationContext import com.facebook.react.bridge.ReadableArray import com.facebook.react.bridge.ReadableMap import com.facebook.react.bridge.ReadableType import com.facebook.react.bridge.WritableMap import com.facebook.react.bridge.buildReadableMap import com.facebook.react.common.ReactConstants import com.facebook.react.module.annotations.ReactModule import com.facebook.react.modules.network.CustomClientBuilder import com.facebook.react.modules.network.ForwardingCookieHandler import java.io.IOException import java.net.URI import java.net.URISyntaxException import java.util.HashMap import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response import okhttp3.WebSocket import okhttp3.WebSocketListener import okio.ByteString @ReactModule(name = WebSocketModule.NAME) public class WebSocketModule(context: ReactApplicationContext) : NativeWebSocketModuleSpec(context) { /** Callback with one onMessage overload for String payloads and one for ByteString payloads; the code that invokes it is not visible in this copy. */ public interface ContentHandler { public fun onMessage(text: String, params: WritableMap) public fun onMessage(byteString: ByteString, params: WritableMap) } private val webSocketConnections: MutableMap = ConcurrentHashMap() private val contentHandlers: MutableMap = ConcurrentHashMap() private val cookieHandler: ForwardingCookieHandler = ForwardingCookieHandler() /** Closes every tracked socket with code 1001 and a null reason, then empties both the connection map and the content handler map. */ override fun invalidate() { for (socket in webSocketConnections.values) { socket.close(1_001 /* endpoint is going away */, null) } webSocketConnections.clear() 
    contentHandlers.clear()
  }

  /**
   * Emits [eventName] with [params] via emitDeviceEvent on the React application context,
   * but only when that context reports an active React instance; otherwise nothing
   * happens.
   */
  private fun sendEvent(eventName: String, params: ReadableMap) {
    val reactAppContext = reactApplicationContext
    if (reactAppContext.hasActiveReactInstance()) {
      reactAppContext.emitDeviceEvent(eventName, params)
    }
  }

  /**
   * Stores or removes the [ContentHandler] kept for [id].
   *
   * @param id key under which the handler is kept; presumably a socket id, since connect
   *   derives an Int id from its socketID argument (confirm against the restored file)
   * @param contentHandler handler to store, or null to remove whatever is stored for [id]
   */
  public fun setContentHandler(id: Int, contentHandler: ContentHandler?) {
    if (contentHandler != null) {
      contentHandlers[id] = contentHandler
    } else {
      contentHandlers.remove(id)
    }
  }

  /**
   * Builds the OkHttp client and the request for the WebSocket at [url].
   *
   * What the remaining text shows, in order: connect and write timeouts of 10 seconds and
   * a read timeout of 0 (the original comment reads: disable timeouts for read); the
   * custom client builder applied via applyCustomBuilder; the request tagged with the Int
   * form of [socketID]; a Cookie header added when getCookie returns a value;
   * string-valued entries of the `headers` map in [options] copied onto the request, with
   * other entries logged through FLog.w and skipped; an origin header from
   * getDefaultOrigin added unless the caller supplied one (the name is matched ignoring
   * case); and the start of a loop over [protocols].
   *
   * NOTE(review): the body breaks off inside that loop in this copy, so how the protocols
   * are used and how the socket is opened cannot be documented from here.
   *
   * @param url target URL; also passed to getCookie and getDefaultOrigin
   * @param protocols optional list; only consulted when non-null and non-empty
   * @param options optional map; only its `headers` entry is read in the visible part, and
   *   only when that entry is of map type
   * @param socketID identifier converted with toInt and used as the request tag
   */
  override fun connect(
      url: String,
      protocols: ReadableArray?,
      options: ReadableMap?,
      socketID: Double,
  ) {
    val id = socketID.toInt()
    val okHttpBuilder =
        OkHttpClient.Builder()
            .connectTimeout(10, TimeUnit.SECONDS)
            .writeTimeout(10, TimeUnit.SECONDS)
            .readTimeout(0, TimeUnit.MINUTES) // Disable timeouts for read applyCustomBuilder(okHttpBuilder) val client = okHttpBuilder.build() val builder = Request.Builder().tag(id).url(url) val cookie = this.getCookie(url) if (cookie != null) { builder.addHeader("Cookie", cookie) } var hasOriginHeader = false if (options?.hasKey("headers") == true && options.getType("headers") == ReadableType.Map) { val headers = checkNotNull(options.getMap("headers")) val iterator = headers.keySetIterator() while (iterator.hasNextKey()) { val key = iterator.nextKey() if (ReadableType.String == headers.getType(key)) { if (key.equals("origin", ignoreCase = true)) { hasOriginHeader = true } builder.addHeader( key, checkNotNull(headers.getString(key)) { "value for name $key == null" }, ) } else { FLog.w(ReactConstants.TAG, "Ignoring: requested $key, value not a string") } } } if (!hasOriginHeader) { builder.addHeader("origin", getDefaultOrigin(url)) } if (protocols != null && protocols.size() > 0) { val protocolsValue = StringBuilder("") /* NOTE(review): source text is missing in the next statement, between the two dots and the closing angle bracket; what follows it is the tail of a different function. See the note at the top of the file. */ for (i in 0..>()) val cookieList = cookieMap["Cookie"] if (cookieList.isNullOrEmpty()) { return null } return cookieList[0] } catch (e: URISyntaxException) { throw IllegalArgumentException("Unable to get cookie from $uri") } catch (e: IOException) { throw IllegalArgumentException("Unable to get cookie 
from $uri") } } override fun addListener(eventName: String): Unit = Unit override fun removeListeners(count: Double): Unit = Unit public companion object { public const val NAME: String = NativeWebSocketModuleSpec.NAME private var customClientBuilder: CustomClientBuilder? = null @JvmStatic public fun setCustomClientBuilder(ccb: CustomClientBuilder?) { customClientBuilder = ccb } private fun applyCustomBuilder(builder: OkHttpClient.Builder) { customClientBuilder?.apply(builder) } /** * Get the default HTTP(S) origin for a specific WebSocket URI * * @param uri * @return A string of the endpoint converted to HTTP protocol (http[s]://host[:port]) */ private fun getDefaultOrigin(uri: String): String { try { val requestURI = URI(uri) val scheme = when (requestURI.scheme) { "wss" -> "https" "ws" -> "http" "http", "https" -> requestURI.scheme else -> "" } val defaultOrigin = if (requestURI.port != -1) { String.format("%s://%s:%s", scheme, requestURI.host, requestURI.port) } else { String.format("%s://%s", scheme, requestURI.host) } return defaultOrigin } catch (e: URISyntaxException) { throw IllegalArgumentException("Unable to set $uri as default origin header") } } } }