// %BANNER_BEGIN%
// ---------------------------------------------------------------------
// %COPYRIGHT_BEGIN%
// Copyright (c) 2022-2023 Magic Leap, Inc. All Rights Reserved.
// Use of this file is governed by the Magic Leap 2 Software License Agreement, located here: https://www.magicleap.com/software-license-agreement-ml2
// Terms and conditions applicable to third-party materials accompanying this distribution may also be found in the top-level NOTICE file appearing herein.
// %COPYRIGHT_END%
// ---------------------------------------------------------------------
// %BANNER_END%

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using UnityEngine;

namespace MagicLeap.Spectator
{
    /// <summary>
    /// A TLS-wrapped TCP connection that writes framed <see cref="Message"/>s and dispatches
    /// received messages to callbacks registered per message <see cref="Guid"/>.
    /// </summary>
    /// <remarks>
    /// Frames written by <see cref="Write"/> are: the message size as produced by
    /// <c>BitConverter.GetBytes(int)</c> (4 bytes), the 16-byte GUID, then <c>size</c> payload bytes.
    /// <see cref="Disconnect"/> writes a bare size of -1. Reading is delegated to the
    /// <see cref="Message"/> constructor on a long-running background task.
    /// </remarks>
    public class TcpConnection
    {
        #region Private variables

        // Bytes preceding the payload in every frame: 4-byte size + 16-byte Guid.
        private const int HeaderSize = 20;

        // The only certificate accepted by ValidateServerCertificate (pinned by subject and hash).
        private const string PinnedCertificateSubject = "CN=SPECTATOR_CERTIFICATE";
        private const string PinnedCertificateHash = "FDA340DB2A537FA1DDC44BE723D23DCB6DF3755E";

        // Timeouts, in milliseconds, for the TLS handshake and for a single frame write.
        private const int HandshakeTimeoutMs = 1000;
        private const int WriteTimeoutMs = 500;

        private string address = "NULL";
        private int port = 0;
        private TcpClient socketConnection = null;
        private Task readLoop = null;
        private NetworkStream rawStream = null;
        private SslStream stream = null;

        // Serializes every write to 'stream' and its teardown. A dedicated object is required because
        // 'stream' itself is reassigned (and set to null) by ReadLoop, so it cannot be a lock target.
        private readonly object writeLock = new object();

        private readonly ArrayPool<byte> sendPool = ArrayPool<byte>.Create();

        // Handlers per message GUID. HashSet is not thread-safe, so each set is locked whenever it is
        // mutated or snapshotted (registration happens on caller threads, dispatch on the read task).
        private readonly ConcurrentDictionary<Guid, HashSet<Action<Message>>> actionMap = new();

        #endregion

        #region Public events

        /// <summary>Raised on the read task after the read loop ends and resources are released.</summary>
        public event Action OnDisconnect = null;

        #endregion

        #region Public properties

        /// <summary>True while the connection is established and the read loop should keep running.</summary>
        public SafeBool Connected { get; private set; } = new SafeBool();

        /// <summary>Was our connection purposefully closed (set by <see cref="Disconnect"/>).</summary>
        public SafeBool Closed { get; private set; } = new SafeBool();

        /// <summary>Remote IP address captured in <see cref="SetupConnection"/>, or "NULL" before that.</summary>
        public string Address => address;

        /// <summary>Remote port captured in <see cref="SetupConnection"/>, or 0 before that.</summary>
        public int Port => port;

        #endregion

        #region Constructors / Destructor

        public TcpConnection() { }

        ~TcpConnection()
        {
            // Disconnect never throws, which matters here: an exception escaping a finalizer
            // terminates the process.
            Disconnect();
        }

        #endregion

        #region Public methods

        /// <summary>
        /// Wraps an already-connected socket in an <see cref="SslStream"/>, performs the TLS handshake
        /// and, on success, starts the background read loop.
        /// </summary>
        /// <param name="socketConnection">Connected client socket; this instance takes ownership of it on success.</param>
        /// <param name="receiveBufferSize">Value assigned to <see cref="TcpClient.ReceiveBufferSize"/>.</param>
        /// <param name="sendBufferSize">Value assigned to <see cref="TcpClient.SendBufferSize"/>.</param>
        /// <param name="sslServerCertificate">
        /// Raw certificate bytes. When non-null this side acts as the TLS server; when null it acts as
        /// the TLS client and validates the peer with <see cref="ValidateServerCertificate"/>.
        /// </param>
        /// <returns>True when the handshake completed and the read loop was started.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="socketConnection"/> is null.</exception>
        public bool SetupConnection(TcpClient socketConnection, int receiveBufferSize, int sendBufferSize, byte[] sslServerCertificate = null)
        {
            if (socketConnection == null)
                throw new ArgumentNullException(nameof(socketConnection));

            if (Connected)
            {
                Debug.LogWarning("TcpConnection: SetupConnection() - Already connected!");
                return false;
            }

            Closed.Set(false);

            this.socketConnection = socketConnection;
            socketConnection.NoDelay = true;
            rawStream = socketConnection.GetStream();
            stream = new SslStream(rawStream, false, new RemoteCertificateValidationCallback(ValidateServerCertificate));
            stream.WriteTimeout = WriteTimeoutMs;
            socketConnection.ReceiveBufferSize = receiveBufferSize;
            socketConnection.SendBufferSize = sendBufferSize;

            if (!(socketConnection.Client.RemoteEndPoint is IPEndPoint ipEndPoint))
            {
                Debug.LogError("TcpConnection: SetupConnection() - Socket has no remote IP endpoint");
                return false;
            }

            address = ipEndPoint.Address.ToString();
            port = ipEndPoint.Port;

            // NOTE(review): on handshake failure the SslStream/socket are left for the caller to
            // close, matching previous behavior.
            bool authenticated = sslServerCertificate != null
                ? AuthenticateAsServer(sslServerCertificate)
                : AuthenticateAsClient();
            if (!authenticated)
                return false;

            Connected.Set(true);
            readLoop = new Task(ReadLoop, TaskCreationOptions.LongRunning);
            readLoop.Start();
            return true;
        }

        /// <summary>
        /// TLS certificate callback that accepts only the pinned Spectator certificate (matched by
        /// subject and <see cref="X509Certificate.GetCertHashString()"/>), regardless of
        /// <paramref name="sslPolicyErrors"/>.
        /// </summary>
        public static bool ValidateServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            bool isPinned = certificate.Subject == PinnedCertificateSubject
                && certificate.GetCertHashString() == PinnedCertificateHash;
            if (isPinned)
                return true;

            Debug.LogError("TcpConnection: ValidateServerCertificate() - Validation of server certificate failed");
            return false;
        }

        /// <summary>
        /// Purposefully closes the connection: writes a -1 size marker to the peer (best effort), then
        /// sets <see cref="Closed"/> and clears <see cref="Connected"/>. Does nothing when not connected.
        /// Never throws.
        /// </summary>
        public void Disconnect()
        {
            if (!Connected)
                return;

            try
            {
                lock (writeLock)
                {
                    // ReadLoop may already have torn the stream down.
                    if (stream != null)
                    {
                        byte[] closeSignal = BitConverter.GetBytes(-1);
                        stream.Write(closeSignal, 0, closeSignal.Length);
                    }
                }
            }
            catch (Exception ex)
            {
                // Best effort: a failed close signal must not prevent the state update below, and
                // this method also runs from the finalizer, where an escaping exception is fatal.
                Debug.LogWarning($"TcpConnection: Disconnect() - {ex.Message}");
            }

            Closed.Set(true);
            Connected.Set(false);
        }

        /// <summary>Registers <paramref name="action"/> to run for every received message with <paramref name="guid"/>.</summary>
        /// <returns>False if the same action is already registered for this GUID.</returns>
        public bool RegisterCallback(Guid guid, Action<Message> action)
        {
            // GetOrAdd avoids the TryGetValue/TryAdd race that used to make concurrent first
            // registrations for the same GUID fail.
            HashSet<Action<Message>> set = actionMap.GetOrAdd(guid, _ => new HashSet<Action<Message>>());
            lock (set)
                return set.Add(action);
        }

        /// <summary>Removes a callback previously added with <see cref="RegisterCallback"/>.</summary>
        /// <returns>False if the callback was never registered for <paramref name="guid"/>.</returns>
        public bool UnregisterCallback(Guid guid, Action<Message> action)
        {
            if (!actionMap.TryGetValue(guid, out HashSet<Action<Message>> set))
                return false;

            lock (set)
                return set.Remove(action);
        }

        /// <summary>
        /// Sends one framed message. Does nothing when not connected. Any failure (including a write
        /// that does not finish within the write timeout) clears <see cref="Connected"/>; unexpected
        /// failures are also logged.
        /// </summary>
        public void Write(Message message)
        {
            if (!Connected)
                return;

            try
            {
                int frameSize = message.size + HeaderSize;
                byte[] buffer = sendPool.Rent(frameSize);
                Buffer.BlockCopy(BitConverter.GetBytes(message.size), 0, buffer, 0, 4);
                Buffer.BlockCopy(message.guid.ToByteArray(), 0, buffer, 4, 16);
                if (message.data != null && message.size > 0)
                    Buffer.BlockCopy(message.data, 0, buffer, HeaderSize, message.size);

                bool completed;
                lock (writeLock)
                {
                    // ReadLoop nulls the stream during teardown.
                    if (stream == null)
                    {
                        sendPool.Return(buffer);
                        Connected.Set(false);
                        return;
                    }

                    // Seem to need to do this for Android (phone) for some reason
                    completed = stream.WriteAsync(buffer, 0, frameSize).Wait(WriteTimeoutMs);
                }

                // On timeout the write may still be in flight and reading 'buffer', so the buffer is
                // deliberately not returned to the pool in that case.
                if (!completed)
                    throw new TimeoutException("Wait timeout");

                sendPool.Return(buffer);
            }
            catch (AggregateException ex) when (ex.InnerException is IOException || ex.InnerException is SocketException)
            {
                // Task.Wait wraps transport failures; treat them like the unwrapped cases below.
                Connected.Set(false);
            }
            catch (SocketException)
            {
                Connected.Set(false);
            }
            catch (IOException)
            {
                Connected.Set(false);
            }
            catch (Exception ex)
            {
                Debug.LogError($"TcpConnection: Write() - {ex.Message}");
                Connected.Set(false);
            }
        }

        #endregion

        #region Private methods

        // Runs the server side of the TLS handshake; false on failure or timeout.
        private bool AuthenticateAsServer(byte[] sslServerCertificate)
        {
            try
            {
                X509Certificate certificate = new X509Certificate2(sslServerCertificate);
                if (stream.AuthenticateAsServerAsync(certificate).Wait(HandshakeTimeoutMs))
                    return true;

                Debug.LogWarning("TcpConnection: SslStream: AuthenticateAsServer() - Handshake timed out");
                return false;
            }
            catch (Exception ex)
            {
                // InnerException is null for exceptions not raised through Task.Wait (e.g. an invalid
                // certificate blob), so it must not be dereferenced unconditionally.
                Debug.LogError($"TcpConnection: SslStream: AuthenticateAsServer() - {ex.Message} --- {ex.InnerException?.Message}");
                return false;
            }
        }

        // Runs the client side of the TLS handshake against the captured remote address; false on failure or timeout.
        private bool AuthenticateAsClient()
        {
            try
            {
                if (stream.AuthenticateAsClientAsync(address).Wait(HandshakeTimeoutMs))
                    return true;

                Debug.LogWarning("TcpConnection: SslStream: AuthenticateAsClient() - Handshake timed out");
                return false;
            }
            catch (Exception ex)
            {
                Debug.LogError($"TcpConnection: SslStream: AuthenticateAsClient() - {ex.Message} --- {ex.InnerException?.Message}");
                return false;
            }
        }

        // Background loop: reads messages while connected, dispatches them, then tears down and raises OnDisconnect.
        private void ReadLoop()
        {
            try
            {
                while (Connected)
                {
                    // Read a message from our stream
                    Message message = new Message(stream, Connected, Closed);
                    try
                    {
                        Dispatch(message);
                    }
                    finally
                    {
                        // Dispose even when a handler throws.
                        message.Dispose();
                    }
                }
            }
            catch (SocketException)
            {
                Connected.Set(false);
            }
            catch (IOException)
            {
                Connected.Set(false);
            }
            catch (Exception ex)
            {
                Debug.LogError($"TcpConnection: ReadLoop() - {ex.Message}");
                Closed.Set(false);
                Connected.Set(false);
            }

            ReleaseResources();
            OnDisconnect?.Invoke();
        }

        // Invokes every handler registered for the message's GUID, iterating over a snapshot so that
        // concurrent Register/UnregisterCallback calls cannot invalidate the enumeration.
        private void Dispatch(Message message)
        {
            if (!actionMap.TryGetValue(message.guid, out HashSet<Action<Message>> set))
                return;

            Action<Message>[] handlers;
            lock (set)
            {
                handlers = new Action<Message>[set.Count];
                set.CopyTo(handlers);
            }

            foreach (Action<Message> action in handlers)
                action?.Invoke(message);
        }

        // Disposes and clears the streams and socket under writeLock, so no Write/Disconnect is mid-write.
        private void ReleaseResources()
        {
            lock (writeLock)
            {
                if (stream != null)
                {
                    stream.Dispose();
                    stream = null;
                }

                if (rawStream != null)
                {
                    rawStream.Dispose();
                    rawStream = null;
                }

                if (socketConnection != null)
                {
                    socketConnection.Dispose();
                    socketConnection = null;
                }
            }
        }

        #endregion
    }
}