1 module passwd.util;
2 
3 @safe:
4 
5 import std.range.primitives;
6 
7 import passwd.exception;
8 public import passwd.securewipe;
9 
/// Fill `buf` with random bytes of cryptographic quality (obtained from arc4random_buf)
void fillSecureRandom(ubyte[] buf) @nogc nothrow @trusted
{
	// @trusted: the pointer/length pair handed to C is exactly the memory owned by `buf`
	void* dest = buf.ptr;
	immutable size_t count = buf.length;
	arc4random_buf(dest, count);
}
15 
/**
	Parses crypt(3) output or salt in Modular Crypt Format (MCF)

	Accepted shapes are `$algo$salt`, `$algo$salt$digest`, `$algo$params$salt` and `$algo$params$salt$digest`.
	A field is taken as `params` only if it contains '='.  For bcrypt (an `algo_id` starting with "2") the
	cost field is returned in `params` and the combined salt+digest field is split after 22 characters.

	The returned fields are slices of `crypt`.

	Throws: ValueException if `crypt` does not start with '$', contains fewer than 2 or more than 4 '$'
	separators, has parameters but no salt, has a bcrypt salt/digest field that is not 22 or 53
	characters long, or has data left over after the digest.
*/
const(CryptPieces) cryptSplit(const(char)[] crypt) pure
{
	import std.algorithm.iteration : splitter;
	import std.algorithm.searching : canFind, startsWith;
	import std.utf : byCodeUnit;

	enforce!ValueException(crypt.length > 0 && crypt[0] == '$', "Hashed password must start with $");

	// Split on '$' without UTF decoding; only the first 5 fields are kept, but all are counted
	const(char)[][5] fields;
	size_t count = 0;
	foreach (field; crypt.byCodeUnit.splitter('$'))
	{
		if (count < fields.length) fields[count] = field.source;
		++count;
	}
	enforce!ValueException(count >= 3 && count <= 5, "Expected 2 to 4 $ characters in hashed password");

	// fields[0] is the empty text before the leading '$'
	assert (fields[0].length == 0);
	size_t next = 1;

	auto algo_id = fields[next++];

	// Parameters are optional, but always contain an '='; the salt never does because it only uses B64 characters
	const(char)[] params;
	if (fields[next].byCodeUnit.canFind('='))
	{
		params = fields[next++];
		enforce!ValueException(next < count, "Missing salt (or invalid salt containing = character)");
	}

	auto salt_txt = fields[next++];

	const(char)[] digest_txt;
	if (next < count) digest_txt = fields[next++];

	/*
		bcrypt deviates from the usual MCF: its cost parameter has no "name=" prefix (so it was parsed as
		the salt above), and the B64 digest is glued to the 22-character B64 salt with no '$' in between.
		Shift the fields accordingly and split salt from digest by length.
	*/
	if (algo_id.startsWith("2"))
	{
		// Valid example looks like $2b$04$WNiYqMnuLlK9V11NmAKCNeG4nDdfI2Uqvo1MTvCehk2D4F4FSbICy
		params = salt_txt;
		enforce!ValueException(digest_txt.length == 22 || digest_txt.length == 53, "Invalid bcrypt digest length");
		salt_txt = digest_txt[0 .. 22];
		digest_txt = digest_txt[22 .. $];
	}

	enforce!ValueException(next == count, "Trailing data or corrupted format of hashed password");
	return const(CryptPieces)(algo_id, params, salt_txt, digest_txt);
}
73 
///
unittest
{
	// SHA-256 crypt output with an explicit rounds parameter
	const parsed = cryptSplit("$5$rounds=10000$saltstringsaltst$3xv.VbSHBb41AL9AvLeujZkZRBAwqFMz2.opqey6IcA");
	with (parsed)
	{
		assert (algo_id == "5");
		assert (params == "rounds=10000");
		assert (salt_txt == "saltstringsaltst");
		assert (digest_txt == "3xv.VbSHBb41AL9AvLeujZkZRBAwqFMz2.opqey6IcA");
	}
}
83 
unittest
{
	// A bare salt (no parameters, no digest) is well-formed
	cryptSplit("$5$salt");

	// Malformed inputs: empty, too few/many '$', missing leading '$', trailing data, params without salt
	auto malformed = [
		"",
		"$",
		"$$$$$$$$",
		"1$rounds=10$salt$",
		"$5$salt$xyz$extra",
		"$5$rounds=10",
		"$5$salt$rounds=10$xyz",
	];
	foreach (bad; malformed)
	{
		assertThrown!ValueException(cryptSplit(bad));
	}
}
95 
/// Result of parsing MCF data with cryptSplit
struct CryptPieces
{
	/// Standard ID for hashing algorithm
	char[] algo_id;
	/// Extra parameters for hashing algorithm (may be empty; for bcrypt cryptSplit puts the cost field here)
	char[] params;
	/// Plain text salt string
	char[] salt_txt;
	/// Plain text result of hashing algorithm (empty when only a salt string was parsed)
	char[] digest_txt;
}
104 
/**
	Encode data using crypt(3) base 64, returning a new string

	Note: This is *not* the same base 64 as used in many internet standards.

	Params:
		data = bytes to encode

	Returns: the encoded text (empty for empty input)
*/
string cryptB64Encode(const(ubyte)[] data) pure
{
	import std.array : Appender;
	Appender!string encoded;
	// Delegate to the output-range overload
	cryptB64Encode(data, encoded);
	return encoded.data;
}
117 
unittest
{
	// Fixed: the module is std.string (`std..string` does not parse)
	import std.string : representation;
	assert (cryptB64Encode([]) == "");
	// Single bytes need 2 digits, two bytes 3 digits, three bytes 4 digits; least significant bits come first
	assert (cryptB64Encode([0]) == "..");
	assert (cryptB64Encode([1]) == "/.");
	assert (cryptB64Encode("asdfqwer".representation) == "VB5Na3rRZ75");
	assert (cryptB64Encode([0xff]) == "z1");
	assert (cryptB64Encode([0xff, 0xff]) == "zzD");
	assert (cryptB64Encode([0xff, 0xff, 0xff]) == "zzzz");
	assert (cryptB64Encode([0xff, 0xff, 0xff, 0xff]) == "zzzzz1");
}
130 
/**
	Encode data using crypt(3) base 64 to an output range

	Input is processed in groups of up to 3 bytes.  Each group is packed little-endian (first byte in the
	lowest bits) and written as one more digit than it has bytes, which is the minimum needed for its bits.

	Note: This is *not* the same base 64 as used in many internet standards.
*/
void cryptB64Encode(Out)(const(ubyte)[] data, ref Out output) if (isOutputRange!(Out, char))
{
	for (size_t start = 0; start < data.length; start += 3)
	{
		const stop = (data.length - start > 3) ? start + 3 : data.length;
		const group = data[start .. stop];
		assert (group.length >= 1 && group.length <= 3);

		// Last byte goes in first so that the first byte ends up in the lowest 8 bits
		uint packed = 0;
		foreach_reverse (octet; group) packed = (packed << 8) | octet;

		// n + 1 == ceil(n * 8.0 / 6) for 1 <= n <= 3
		cryptB64Chars(packed, output, group.length + 1);
	}
}
153 
@nogc
unittest
{
	import std.utf : byCodeUnit;
	// Encoding into a fixed buffer through an output range must not touch the GC
	ubyte[1] input = [42];
	char[2] encoded;
	auto sink = encoded[].byCodeUnit;
	cryptB64Encode(input[], sink);
	// 42 is digit 'e'; the remaining 2 high bits are zero, giving '.'
	assert (encoded[0] == 'e' && encoded[1] == '.');
}
165 
166 // The base 64 used in MIME and other internet standards uses a different set of characters and has a padding scheme
167 
/**
	Decode crypt(3) base 64 to an output range

	Digits are consumed least significant first.  A full byte is written whenever 8 bits are pending.
	Any leftover bits at the end are written as one final (zero-padded) byte, so decoding may produce one
	more byte than was originally encoded.

	Note: This is *not* the same base 64 as used in many internet standards.

	Throws: ValueException (from cryptB64DecodeChar) on a character outside the crypt(3) alphabet
*/
void cryptB64Decode(Out)(const(char)[] data, ref Out output) if (isOutputRange!(Out, ubyte))
{
	uint pending = 0;
	int pending_bits = 0;
	foreach (char c; data)
	{
		// Each new 6-bit digit lands above the bits already pending
		pending |= cryptB64DecodeChar(c) << pending_bits;
		pending_bits += 6;
		if (pending_bits < 8) continue;
		output.put(cast(ubyte) pending);
		pending >>= 8;
		pending_bits -= 8;
	}

	// pending_bits is always 0, 2, 4 or 6 here, so at most one partial byte remains
	if (pending_bits > 0)
	{
		output.put(cast(ubyte) pending);
		pending >>= 8;
	}
	assert (pending == 0);
}
197 
unittest
{
	import std.array : appender;
	auto sink = appender!(ubyte[]);
	// '!' is outside the crypt(3) alphabet
	assertThrown!ValueException(cryptB64Decode("!!!", sink));
	// "e." carries 12 bits: 42 in the first byte, then 4 zero bits emitted as a trailing 0 byte
	cryptB64Decode("e.", sink);
	assert (sink[] == [42, 0]);
}
206 
unittest
{
	// Encode then decode and compare with the input.  B64 digits carry 6 bits, bytes carry 8, and there is
	// no padding, so the trip can add zero bits at the end.  One trailing '.' (digit 0) is stripped from the
	// encoding, and one trailing 0 byte is stripped from the decoding, before comparing.
	bool testRoundTrip(const(ubyte)[] d)
	{
		import std.algorithm.searching : endsWith;
		import std.array : appender;
		// Fixed: the module is std.string (`std..string` does not parse)
		import std.string : chomp;
		auto encoded_app = appender!(char[]);
		d.cryptB64Encode(encoded_app);
		auto encoded = encoded_app[].chomp(".");
		auto decoded_app = appender!(ubyte[]);
		encoded.cryptB64Decode(decoded_app);
		auto decoded = decoded_app[];
		if (decoded.endsWith([0])) decoded = decoded[0..$-1];
		return d == decoded;
	}

	import std.string : representation;
	// Lengths 0 to 10 cover every remainder modulo 3
	assert (testRoundTrip([]));
	assert (testRoundTrip("0".representation));
	assert (testRoundTrip("01".representation));
	assert (testRoundTrip("012".representation));
	assert (testRoundTrip("0123".representation));
	assert (testRoundTrip("01234".representation));
	assert (testRoundTrip("012345".representation));
	assert (testRoundTrip("0123456".representation));
	assert (testRoundTrip("01234567".representation));
	assert (testRoundTrip("012345678".representation));
	assert (testRoundTrip("0123456789".representation));
}
238 
239 package:
240 
/**
	Write `length` bytes of `data` to an output range, repeating `data` as necessary

	E.g. data = [0, 1, 2] with length = 7 writes 0 1 2 0 1 2 0.

	Params:
		output = range receiving the bytes
		data = bytes to cycle through; must be non-empty unless `length` is 0
		length = total number of bytes to write
*/
void stretchPut(Out)(ref Out output, const(ubyte)[] data, size_t length) if (isOutputRange!(Out, ubyte))
{
	// Writing nothing is valid even for empty `data`.  A repeat(data).joiner construction would spin forever
	// here, because joiner skips empty elements of an infinite range.
	if (length == 0) return;
	assert (data.length > 0, "stretchPut: cannot fill a non-zero length from empty data");

	size_t pos = 0;
	foreach (_; 0 .. length)
	{
		// Bounds-checked read: if asserts are compiled out and `data` is empty, this fails instead of hanging
		output.put(data[pos]);
		if (++pos == data.length) pos = 0;
	}
}
248 
///
unittest
{
	import std.array : appender;
	auto sink = appender!(ubyte[]);
	stretchPut(sink, [0, 1, 2], 7);
	// Two full copies of the input, then its first byte
	assert (sink[] == [0, 1, 2, 0, 1, 2, 0]);
}
257 
/**
	Decompose a permutation into a series of swaps (0 i0) . (1 i1) . (2 i2) ...

	The result is the array [i0, i1, i2, ...].  Applying `swap(a[j], a[i_j])` for j = 0, 1, 2, ... to an
	array `a` rearranges it in place so that `a[j]` becomes the original `a[perm[j]]`.

	E.g., [2 1 0] decomposes to (0 2) . (1 1) . (2 2).  Taking element #2, then #1, then #0 is the same as
	swapping #0 with #2, then #1 with #1 (no-op), then #2 with #2 (no-op), i.e. [2, 1, 2].  The final swap
	is always a no-op, so it is dropped and the result is [2, 1].

	Params:
		perm = a permutation of 0 .. perm.length

	Returns: perm.length - 1 swap targets (empty when perm.length < 2)
*/
size_t[] permSwapDecomposition(size_t[] perm) pure
{
	import std.array : array;
	import std.range : iota;

	const n = perm.length;
	if (n < 2) return [];

	// Replay the permutation on the identity array `cur`, tracking where each value lives.
	// slot_of[x] is the index of value x in cur, so fetching perm[pos] into slot pos is the swap (pos slot_of[perm[pos]]).
	auto cur = iota(n).array;
	auto slot_of = iota(n).array;
	auto swaps = new size_t[n - 1];
	foreach (pos, ref target; swaps)
	{
		const wanted = perm[pos];
		const from = slot_of[wanted];
		target = from;

		// Exchange the values at `pos` and `from`, then update both of their recorded locations
		const displaced = cur[pos];
		cur[from] = displaced;
		cur[pos] = wanted;
		slot_of[displaced] = from;
		slot_of[wanted] = pos;
	}
	return swaps;
}
293 
unittest
{
	import std.algorithm.mutation : swap;
	import std.array : array;
	import std.range : iota;

	assert (permSwapDecomposition([]) == []);
	assert (permSwapDecomposition([2, 1, 0]) == [2, 1]);

	// Replaying the decomposition's swaps on the identity array must reproduce the original permutation
	size_t[][] perms = [
		[0],
		[0, 1],
		[1, 0],

		[1, 0, 4, 2, 3],
		[2, 1, 4, 3, 0],
		[0, 4, 3, 2, 1],
		[1, 2, 0, 3, 4],
		[0, 2, 3, 4, 1],
		[4, 3, 2, 0, 1],

		[4, 3, 0, 2, 1, 5],
		[4, 3, 5, 0, 1, 2],
		[2, 0, 1, 5, 4, 3],
		[4, 5, 0, 2, 1, 3],
		[0, 1, 3, 5, 2, 4],
		[4, 2, 1, 3, 0, 5],
		[5, 2, 4, 1, 0, 3],
	];
	foreach (perm; perms)
	{
		auto replay = iota(perm.length).array;
		foreach (pos, other; permSwapDecomposition(perm)) swap(replay[pos], replay[other]);
		assert (replay == perm);
	}
}
334 
335 private:
336 
/**
	libbsd's portable and robust cryptographic randomness generator

	Fills `nbytes` bytes starting at `buf`.  Explicitly @system: the module-wide `@safe:` label would otherwise
	let @safe code pass an arbitrary pointer/length pair.  Call it only through a @trusted wrapper that derives
	both arguments from one slice, as fillSecureRandom does.
*/
extern(C) void arc4random_buf(void *buf, size_t nbytes) @nogc nothrow @system;
339 
/// crypt(3) base 64 alphabet: the digit with value i is crypt_b64_tab[i] ('.' and '/', then digits, upper case, lower case)
immutable crypt_b64_tab = "./" ~ "0123456789" ~ "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ~ "abcdefghijklmnopqrstuvwxyz";
341 
/**
	Decode a single crypt(3) base 64 digit to its numerical value (0 .. 63)

	Throws: ValueException if `c` is not in the crypt(3) base 64 alphabet
*/
ubyte cryptB64DecodeChar(char c) pure
{
	// '.', '/' and '0'..'9' are contiguous in ASCII, so one range check covers values 0 .. 11
	static assert ('.' + 1 == '/' && '/' + 1 == '0');
	if (c >= '.' && c <= '9') return cast(ubyte)(c - '.');
	if (c >= 'A' && c <= 'Z') return cast(ubyte)(c - 'A' + 12);
	if (c >= 'a' && c <= 'z') return cast(ubyte)(c - 'a' + 38);
	throw new ValueException("Invalid (crypt) base 64 value");
}
358 
unittest
{
	// Every alphabet character must decode back to its own position in the table
	foreach (pos, digit; crypt_b64_tab)
	{
		assert (cryptB64DecodeChar(digit) == pos);
	}
}
366 
/**
	Encode an integer to `num_chars` crypt(3) base 64 digits, writing to an output range

	Digits are written least significant first, 6 bits each.

	Note: the caller is responsible for setting the right `num_chars` to output the value correctly.
*/
void cryptB64Chars(Out)(uint v, ref Out output, size_t num_chars) if (isOutputRange!(Out, char))
{
	foreach (_; 0 .. num_chars)
	{
		const digit = v & 0x3f;
		output.put(crypt_b64_tab[digit]);
		v >>= 6;
	}
	// Leftover bits mean num_chars was too small to represent v
	assert (v == 0);
}