1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
//! Allows you to take an existing request or stream of data and convert it into a
//! WebSocket client.
use header::extensions::Extension;
use header::{
	Origin, WebSocketAccept, WebSocketExtensions, WebSocketKey, WebSocketProtocol, WebSocketVersion,
};
use std::error::Error;
use std::fmt::{self, Display, Formatter};
use std::io;
use stream::Stream;

use hyper::header::{Connection, ConnectionOption, Headers, Protocol, ProtocolName, Upgrade};
use hyper::http::h1::Incoming;
use hyper::method::Method;
use hyper::status::StatusCode;
use hyper::uri::RequestUri;
use unicase::UniCase;

#[cfg(any(feature = "sync", feature = "async"))]
use hyper::version::HttpVersion;

#[cfg(feature = "async")]
pub mod async;

#[cfg(feature = "sync")]
pub mod sync;

/// An incoming HTTP request as parsed by hyper.
///
/// The subject is the `(Method, RequestUri)` pair from the request line; the
/// `headers` and `version` fields of `Incoming` carry the rest of the head.
pub type Request = Incoming<(Method, RequestUri)>;

/// Intermediate representation of a half-created websocket session.
///
/// Use it to examine the client's handshake, accept the requested
/// protocols, route on the path, etc.
///
/// Users should then call `accept` or `reject` to complete the handshake
/// and start a session.
///
/// Note: if the stream in use is `AsyncRead + AsyncWrite`, then asynchronous
/// functions will be available when completing the handshake.
/// Otherwise, if the stream is simply `Read + Write`, blocking functions will
/// be available to complete the handshake.
pub struct WsUpgrade<S, B>
where
	S: Stream,
{
	/// The headers that will be used in the handshake response.
	pub headers: Headers,
	/// The stream that will be used to read from / write to.
	pub stream: S,
	/// The handshake request, filled with useful metadata.
	pub request: Request,
	/// Some buffered data from the stream, if it exists.
	pub buffer: B,
}

impl<S, B> WsUpgrade<S, B>
where
	S: Stream,
{
	/// Select a protocol to use in the handshake response.
	///
	/// The protocol is appended to the response's `Sec-WebSocket-Protocol`
	/// header, which is created if it is not present yet.
	pub fn use_protocol<P>(mut self, protocol: P) -> Self
	where
		P: Into<String>,
	{
		upsert_header!(self.headers; WebSocketProtocol; {
			Some(protos) => protos.0.push(protocol.into()),
			None => WebSocketProtocol(vec![protocol.into()])
		});
		self
	}

	/// Select an extension to use in the handshake response.
	///
	/// The extension is appended to the response's `Sec-WebSocket-Extensions`
	/// header, which is created if it is not present yet.
	pub fn use_extension(mut self, extension: Extension) -> Self {
		upsert_header!(self.headers; WebSocketExtensions; {
			Some(exts) => exts.0.push(extension),
			None => WebSocketExtensions(vec![extension])
		});
		self
	}

	/// Select multiple extensions to use in the connection.
	///
	/// The extensions are appended, in iteration order, to the response's
	/// `Sec-WebSocket-Extensions` header. If `extensions` yields nothing the
	/// response headers are left untouched.
	pub fn use_extensions<I>(mut self, extensions: I) -> Self
	where
		I: IntoIterator<Item = Extension>,
	{
		let mut extensions: Vec<Extension> = extensions.into_iter().collect();
		// RFC 6455 requires at least one extension in a
		// `Sec-WebSocket-Extensions` header, so never create an empty one.
		if extensions.is_empty() {
			return self;
		}
		upsert_header!(self.headers; WebSocketExtensions; {
			Some(exts) => exts.0.append(&mut extensions),
			None => WebSocketExtensions(extensions)
		});
		self
	}

	/// Drop the connection without saying anything.
	///
	/// This consumes `self`, dropping the underlying stream without writing
	/// any response to it.
	pub fn drop(self) {
		::std::mem::drop(self);
	}

	/// A list of protocols requested from the client.
	///
	/// Returns an empty slice if the request has no `Sec-WebSocket-Protocol`
	/// header.
	pub fn protocols(&self) -> &[String] {
		self.request
			.headers
			.get::<WebSocketProtocol>()
			.map(|p| p.0.as_slice())
			.unwrap_or(&[])
	}

	/// A list of extensions requested from the client.
	///
	/// Returns an empty slice if the request has no `Sec-WebSocket-Extensions`
	/// header.
	pub fn extensions(&self) -> &[Extension] {
		self.request
			.headers
			.get::<WebSocketExtensions>()
			.map(|e| e.0.as_slice())
			.unwrap_or(&[])
	}

	/// The client's `Sec-WebSocket-Key`, as its 16 raw bytes.
	pub fn key(&self) -> Option<&[u8; 16]> {
		self.request.headers.get::<WebSocketKey>().map(|k| &(k.0).0)
	}

	/// The client's websocket version (`Sec-WebSocket-Version`), if sent.
	pub fn version(&self) -> Option<&WebSocketVersion> {
		self.request.headers.get::<WebSocketVersion>()
	}

	/// The original request URI, rendered as a string.
	pub fn uri(&self) -> String {
		self.request.subject.1.to_string()
	}

	/// Origin of the client, taken from the `Origin` header if present.
	pub fn origin(&self) -> Option<&str> {
		self.request.headers.get::<Origin>().map(|o| &o.0 as &str)
	}

	/// Write the handshake response to the stream: the status line (using the
	/// request's HTTP version), `self.headers`, and the blank line that ends
	/// the header section.
	#[cfg(feature = "sync")]
	fn send(&mut self, status: StatusCode) -> io::Result<()> {
		// Build the whole response in one buffer before writing it out.
		let data = format!(
			"{} {}\r\n{}\r\n",
			self.request.version, status, self.headers
		);
		self.stream.write_all(data.as_bytes())?;
		// Flush so the response cannot sit in a buffered stream while the
		// caller goes on to wait for frames from the client.
		self.stream.flush()
	}

	/// Fill `self.headers` with the handshake response headers and return the
	/// status code to answer with (`101 Switching Protocols`).
	///
	/// `custom` headers are merged in first, so the `Sec-WebSocket-Accept`,
	/// `Connection: Upgrade` and `Upgrade: websocket` headers set afterwards
	/// replace any custom values for those same headers.
	///
	/// # Panics
	///
	/// Panics if the request has no `Sec-WebSocket-Key` header. `validate`
	/// rejects such requests, so this can only happen when a `WsUpgrade` is
	/// assembled by hand from a request that was never validated.
	#[doc(hidden)]
	pub fn prepare_headers(&mut self, custom: Option<&Headers>) -> StatusCode {
		if let Some(headers) = custom {
			self.headers.extend(headers.iter());
		}
		let key = self.request.headers.get::<WebSocketKey>().expect(
			"WsUpgrade request has no Sec-WebSocket-Key; it must pass `validate` first",
		);
		self.headers.set(WebSocketAccept::new(key));
		self.headers
			.set(Connection(vec![ConnectionOption::ConnectionHeader(
				UniCase("Upgrade".to_string()),
			)]));
		self.headers
			.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)]));

		StatusCode::SwitchingProtocols
	}
}

/// Errors that can occur when one tries to upgrade a connection to a
/// websocket connection.
#[derive(Debug)]
pub enum HyperIntoWsError {
	/// The HTTP method in a valid websocket upgrade request must be GET
	MethodNotGet,
	/// The request's HTTP version cannot be upgraded; `validate` returns this
	/// for HTTP/0.9 and HTTP/1.0 requests
	// NOTE(review): this used to be documented as "HTTP 2 is not supported",
	// but `validate` does not reject HTTP/2 — confirm the intended scope.
	UnsupportedHttpVersion,
	/// Currently only WebSocket13 is supported (RFC6455); returned when a
	/// `Sec-WebSocket-Version` header is present with any other value
	UnsupportedWebsocketVersion,
	/// A websocket upgrade request must contain a `Sec-WebSocket-Key` header
	NoSecWsKeyHeader,
	/// The `Upgrade` header does not list the `websocket` protocol
	NoWsUpgradeHeader,
	/// A websocket upgrade request must contain an `Upgrade` header
	NoUpgradeHeader,
	/// The `Connection` header has no `Upgrade` option (compared
	/// case-insensitively)
	NoWsConnectionHeader,
	/// A websocket upgrade request must contain a `Connection` header
	NoConnectionHeader,
	/// IO error from reading the underlying socket
	Io(io::Error),
	/// Error while parsing an incoming request
	Parsing(::hyper::error::Error),
}

impl Display for HyperIntoWsError {
	fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
		fmt.write_str(self.description())
	}
}

impl Error for HyperIntoWsError {
	fn description(&self) -> &str {
		use self::HyperIntoWsError::*;
		match *self {
			MethodNotGet => "Request method must be GET",
			UnsupportedHttpVersion => "Unsupported request HTTP version",
			UnsupportedWebsocketVersion => "Unsupported WebSocket version",
			NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header",
			NoWsUpgradeHeader => "Invalid Upgrade WebSocket header",
			NoUpgradeHeader => "Missing Upgrade WebSocket header",
			NoWsConnectionHeader => "Invalid Connection WebSocket header",
			NoConnectionHeader => "Missing Connection WebSocket header",
			Io(ref e) => e.description(),
			Parsing(ref e) => e.description(),
		}
	}

	fn cause(&self) -> Option<&Error> {
		match *self {
			HyperIntoWsError::Io(ref e) => Some(e),
			HyperIntoWsError::Parsing(ref e) => Some(e),
			_ => None,
		}
	}
}

impl From<io::Error> for HyperIntoWsError {
	fn from(err: io::Error) -> Self {
		HyperIntoWsError::Io(err)
	}
}

impl From<::hyper::error::Error> for HyperIntoWsError {
	fn from(err: ::hyper::error::Error) -> Self {
		HyperIntoWsError::Parsing(err)
	}
}

#[cfg(feature = "async")]
impl From<::codec::http::HttpCodecError> for HyperIntoWsError {
	/// Maps a codec I/O failure to `Io` and a codec HTTP failure to `Parsing`.
	fn from(src: ::codec::http::HttpCodecError) -> Self {
		use ::codec::http::HttpCodecError::{Http, Io};

		match src {
			Io(io_err) => HyperIntoWsError::Io(io_err),
			Http(http_err) => HyperIntoWsError::Parsing(http_err),
		}
	}
}

/// Check whether an incoming request is a valid WebSocket upgrade attempt.
///
/// The checks run in the order below and the first failure is returned:
///
/// 1. the method must be `GET` (`MethodNotGet`);
/// 2. the HTTP version must not be 0.9 or 1.0 (`UnsupportedHttpVersion`);
/// 3. if a `Sec-WebSocket-Version` header is present it must be 13
///    (`UnsupportedWebsocketVersion`); a missing header is accepted;
/// 4. a `Sec-WebSocket-Key` header must be present (`NoSecWsKeyHeader`);
/// 5. an `Upgrade` header must be present (`NoUpgradeHeader`) and list the
///    `websocket` protocol (`NoWsUpgradeHeader`);
/// 6. a `Connection` header must be present (`NoConnectionHeader`) and
///    contain an `Upgrade` option, compared case-insensitively
///    (`NoWsConnectionHeader`).
#[cfg(any(feature = "sync", feature = "async"))]
pub fn validate(
	method: &Method,
	version: HttpVersion,
	headers: &Headers,
) -> Result<(), HyperIntoWsError> {
	if *method != Method::Get {
		return Err(HyperIntoWsError::MethodNotGet);
	}

	let predates_http11 = version == HttpVersion::Http09 || version == HttpVersion::Http10;
	if predates_http11 {
		return Err(HyperIntoWsError::UnsupportedHttpVersion);
	}

	match headers.get::<WebSocketVersion>() {
		Some(ws_version) if *ws_version != WebSocketVersion::WebSocket13 => {
			return Err(HyperIntoWsError::UnsupportedWebsocketVersion);
		}
		_ => {}
	}

	// Only presence matters here; the key itself is consumed when the
	// response headers are prepared.
	headers
		.get::<WebSocketKey>()
		.ok_or(HyperIntoWsError::NoSecWsKeyHeader)?;

	let upgrade = headers
		.get::<Upgrade>()
		.ok_or(HyperIntoWsError::NoUpgradeHeader)?;
	let offers_websocket = upgrade
		.0
		.iter()
		.any(|proto| proto.name == ProtocolName::WebSocket);
	if !offers_websocket {
		return Err(HyperIntoWsError::NoWsUpgradeHeader);
	}

	let connection = headers
		.get::<Connection>()
		.ok_or(HyperIntoWsError::NoConnectionHeader)?;
	let requests_upgrade = connection.0.iter().any(|option| match *option {
		ConnectionOption::ConnectionHeader(ref name) => {
			UniCase(name as &str) == UniCase("upgrade")
		}
		_ => false,
	});

	if requests_upgrade {
		Ok(())
	} else {
		Err(HyperIntoWsError::NoWsConnectionHeader)
	}
}