#!/usr/bin/env python3

"""
This application is a stress tester for libjio. It's not a traditional stress
test like fsx (which can be used to test libjio using the preloading library),
but uses fault injection to check how the library behaves under random
failures.
"""

import sys
import os
import random
import traceback
import libjio

try:
	import fiu
except ImportError:
	print()
	print("Error: unable to load fiu module. This test needs libfiu")
	print("support. Please install libfiu and recompile libjio with FI=1.")
	print()
	raise

#
# Auxiliary stuff
#

# Rolling single-digit counter; each getbytes() call advances it so that
# consecutive calls produce visibly different fill bytes.
gbcount = 0
def getbytes(n):
	"""Return n bytes, all equal to the ASCII digit of the next counter value.

	The module-level counter gbcount is advanced (modulo 10) before it is
	used, so the first call yields b'1' * n, the next b'2' * n, and so on,
	wrapping from '9' back to '0'.
	"""
	global gbcount
	gbcount = (gbcount + 1) % 10
	digit = str(gbcount).encode('ascii')
	return digit * n

def randfrange(maxend, maxsize):
	"""Return a random (start, end) pair with 0 <= start <= end <= maxend - 1.

	start is drawn uniformly from [0, maxend - 1]. The length (end - start)
	is always smaller than maxsize and never takes the range beyond
	maxend - 1. A maxsize below 1 produces an empty range (start == end)
	instead of failing on a modulo by zero.

	Raises ValueError (from random.randint) if maxend is less than 1.
	"""
	start = random.randint(0, maxend - 1)
	# always draw the second number so the amount of randomness consumed
	# does not depend on maxsize
	room = random.randint(0, (maxend - 1) - start)
	if maxsize > 0:
		size = room % maxsize
	else:
		size = 0
	return start, start + size

def randfloat(min, max):
	"""Return a random float between min and max.

	The previous implementation used a modulo, which only covered the
	whole interval when (max - min) was below 1 and raised
	ZeroDivisionError when max == min. random.uniform() scales the random
	value instead, so any interval works and max == min returns min.
	"""
	return random.uniform(min, max)

class ConsistencyError (Exception):
	"""Raised by the range verifiers when the file contents do not match
	what the test expected to find."""

def jfsck(fname):
	"""Run libjio.jfsck() on fname and return its result.

	An IOError whose first argument is libjio.J_ENOJOURNAL is not treated
	as a failure: in that case { 'total': 0 } is returned. Any other
	IOError is re-raised unchanged.
	"""
	try:
		return libjio.jfsck(fname)
	except IOError as err:
		if err.args[0] != libjio.J_ENOJOURNAL:
			raise
	return { 'total': 0 }

def comp_cont(bytes):
	"""Run-length summarise a sequence, for compact display.

	'aaaabbcc' -> [ ('a', 4), ('b', 2), ('c', 2) ]

	Iterating a bytes/bytearray object yields integers, so for those the
	first element of each pair is the byte value rather than a character.
	Empty input returns an empty list and a single element returns one
	pair; both used to raise.
	"""
	runs = []
	prev = None
	count = 0
	for item in bytes:
		# count == 0 means no run has been started yet, so prev is not
		# a real element and must not be compared
		if count and item == prev:
			count += 1
			continue

		if count:
			runs.append((prev, count))
		prev = item
		count = 1

	# flush the run that was still open when the input ended
	if count:
		runs.append((prev, count))
	return runs

def pread(fd, start, end):
	"""Read the bytes in [start, end) from fd without moving its position.

	fd must be a seekable file object opened in binary mode. The current
	position is saved first and restored afterwards, even if reading
	fails.

	Returns exactly (end - start) bytes. If the file ends early the
	assertion below fails; the previous end-of-file test compared against
	the text string '' and so never matched on a binary file, leaving the
	loop spinning forever on a short read.
	"""
	ppos = fd.tell()
	total = end - start
	chunks = []
	c = 0
	try:
		fd.seek(start, 0)
		while c < total:
			n = fd.read(total - c)
			if not n:
				# end of file (b'' for binary files)
				break
			c += len(n)
			chunks.append(n)
	finally:
		fd.seek(ppos, 0)
	assert c == total
	return b''.join(chunks)

#
# A range of bytes inside a file, used inside the transactions
#
# Note it can't "remember" the fd as it may change between prepare() and
# verify().
#

class Range:
	"""A random byte range of a file, with the state needed to verify it.

	The range is [start, end). An "extended" range [ext_start, ext_end)
	adds up to 32 bytes of context on each side (clamped to the file) so
	that writes can be checked for not touching neighbouring bytes.

	A range is used either as a read ('r') or as a write ('w'); the
	prepare_*() methods select which, and verify() checks accordingly.
	"""

	def __init__(self, fsize, maxlen):
		"""Pick a random range inside a file of fsize bytes, shorter than
		maxlen bytes."""
		# public
		self.start, self.end = randfrange(fsize, maxlen)
		self.new_data = None
		self.type = 'r'

		# private
		self.prev_data = None
		self.new_data_ctx = None
		self.read_buf = None

		# read an extended range so we can check we
		# only wrote what we were supposed to
		self.ext_start = max(0, self.start - 32)
		self.ext_end = min(fsize, self.end + 32)

	def overlaps(self, other):
		"""Return True if the extended ranges of self and other share at
		least one position (touching end points count as overlapping)."""
		return self.ext_start <= other.ext_end and \
			other.ext_start <= self.ext_end

	def prepare_r(self):
		"""Mark the range as a read and allocate the zero-filled buffer
		the read is expected to fill."""
		self.type = 'r'
		self.read_buf = bytearray(self.end - self.start)

	def verify_r(self, fd):
		"""Check that read_buf matches what fd holds in [start, end).

		Raises ConsistencyError (after printing the range) otherwise.
		"""
		real_data = pread(fd, self.start, self.end)
		if real_data != self.read_buf:
			print('Corruption detected')
			self.show(fd)
			raise ConsistencyError

	def prepare_w(self, fd):
		"""Mark the range as a write and record what to expect afterwards.

		Saves the current contents of the extended range, generates the
		data to write, and builds the extended contents expected once
		the write has been applied. Returns (new_data, start).
		"""
		self.type = 'w'
		self.prev_data = pread(fd, self.ext_start, self.ext_end)

		self.new_data = getbytes(self.end - self.start)

		# Slice the tail with a positive offset: a negative offset of
		# zero (ext_end == end) would select the whole buffer instead
		# of nothing.
		head = self.prev_data[:self.start - self.ext_start]
		tail = self.prev_data[self.end - self.ext_start:]
		self.new_data_ctx = head + self.new_data + tail

		return self.new_data, self.start

	def verify_w(self, fd):
		"""Check that the extended range holds either the old or the new
		contents, i.e. the write was applied completely or not at all.

		Raises ConsistencyError (after printing the range) otherwise.
		"""
		# NOTE: fd must be a real file
		real_data = pread(fd, self.ext_start, self.ext_end)
		if real_data not in (self.prev_data, self.new_data_ctx):
			print('Corruption detected')
			self.show(fd)
			raise ConsistencyError

	def verify(self, fd):
		"""Run the verification that matches the range type."""
		if self.type == 'r':
			self.verify_r(fd)
		else:
			self.verify_w(fd)

	def show(self, fd):
		"""Print the real and expected contents of the range, run-length
		compressed.

		The span shown is the one the matching verify_*() compares:
		the extended range for writes, the plain range for reads. This
		keeps the 'Real' line directly comparable with the lines that
		follow it.
		"""
		if self.type == 'w':
			lo, hi = self.ext_start, self.ext_end
		else:
			lo, hi = self.start, self.end

		real_data = pread(fd, lo, hi)
		print('Range:', lo, hi)
		print('Real:', comp_cont(real_data))
		if self.type == 'w':
			print('Prev:', comp_cont(self.prev_data))
			print('New: ', comp_cont(self.new_data_ctx))
		else:
			print('Buf:', comp_cont(self.read_buf))
		print()


#
# Transactions
#

class T_base:
	"""Interface for the transaction types.

	Every method here is a no-op; the concrete transaction classes
	override them.
	"""

	def __init__(self, f, jf, fsize):
		"""Takes the plain file, the libjio file and the file size."""

	def prepare(self):
		"""Set up the transaction before it is applied."""

	def apply(self):
		"""Perform the transaction."""

	def verify(self, write_only = False):
		"""Check the file after the transaction; write_only asks for
		only the written ranges to be checked."""

class T_jwrite (T_base):
	"""Transaction made of a single jf.pwrite() over one random range."""

	def __init__(self, f, jf, fsize):
		self.f, self.jf, self.fsize = f, jf, fsize

		# a single write is limited to 1/256th of the file, and to 16 MiB
		size_cap = 16 * 1024 * 1024
		self.maxoplen = min(int(fsize / 256), size_cap)
		self.range = Range(self.fsize, self.maxoplen)

	def prepare(self):
		"""Record the pre-write state and generate the data to write."""
		self.range.prepare_w(self.f)

	def apply(self):
		"""Write the new data at the range start through the libjio file."""
		target = self.range
		self.jf.pwrite(target.new_data, target.start)

	def verify(self, write_only = False):
		"""Verify the range against the plain file.

		write_only is accepted for interface compatibility and is not
		used here.
		"""
		self.range.verify(self.f)

class T_writeonly (T_base):
	"""Transaction that adds several non-overlapping writes to a single
	libjio transaction and commits them together."""

	def __init__(self, f, jf, fsize):
		self.f, self.jf, self.fsize = f, jf, fsize

		# favour many small ops
		self.maxoplen = 1024 * 1024
		self.nops = random.randint(1, 26)

		self.ranges = []

		# Keep drawing candidate ranges until enough non-overlapping ones
		# were found; the number of attempts is bounded so we may end up
		# with fewer than nops ranges.
		attempts = 0
		while len(self.ranges) < self.nops and attempts < self.nops * 1.25:
			attempts += 1
			candidate = Range(self.fsize, self.maxoplen)
			if any(candidate.overlaps(r) for r in self.ranges):
				continue
			self.ranges.append(candidate)

	def prepare(self):
		"""Prepare every range as a write."""
		for rng in self.ranges:
			rng.prepare_w(self.f)

	def apply(self):
		"""Add all the writes to one new transaction and commit it."""
		trans = self.jf.new_trans()
		for rng in self.ranges:
			trans.add_w(rng.new_data, rng.start)
		trans.commit()

	def verify(self, write_only = False):
		"""Verify every range; on a ConsistencyError print all the ranges
		for context and re-raise."""
		try:
			for rng in self.ranges:
				rng.verify(self.f)
		except ConsistencyError:
			# show context on errors
			separator = "-" * 50
			print(separator)
			for rng in self.ranges:
				rng.show(self.f)
			print(separator)
			raise

class T_readwrite (T_writeonly):
	"""Like T_writeonly, but each range is randomly made a read or a
	write."""

	def __init__(self, f, jf, fsize):
		super().__init__(f, jf, fsize)
		self.read_ranges = []

	def prepare(self):
		"""Prepare each range as a write or as a read, chosen at random."""
		for rng in self.ranges:
			as_write = random.choice((True, False))
			if as_write:
				rng.prepare_w(self.f)
			else:
				rng.prepare_r()

	def apply(self):
		"""Add every read and write to one new transaction and commit it."""
		trans = self.jf.new_trans()
		for rng in self.ranges:
			if rng.type == 'r':
				trans.add_r(rng.read_buf, rng.start)
				continue
			trans.add_w(rng.new_data, rng.start)
		trans.commit()

	def verify(self, write_only = False):
		"""Verify the ranges, skipping the reads when write_only is set.

		On a ConsistencyError all the ranges are printed for context and
		the error is re-raised.
		"""
		try:
			for rng in self.ranges:
				if not (write_only and rng.type == 'r'):
					rng.verify(self.f)
		except ConsistencyError:
			# show context on errors
			separator = "-" * 50
			print(separator)
			for rng in self.ranges:
				rng.show(self.f)
			print(separator)
			raise

# Transaction classes the stresser chooses from at random
t_list = [
	T_jwrite,
	T_writeonly,
	T_readwrite,
]


#
# The test itself
#

class Stresser:
	"""Runs random transactions against a libjio file and verifies them.

	With use_fi set, each transaction is applied in a forked child with
	fault injection enabled; a child that exits with a non-zero status
	counts as a simulated failure, after which the file is checked with
	jfsck and reopened. With use_as set, the file is opened with
	libjio.J_LINGER and autosync is started.
	"""

	def __init__(self, fname, fsize, nops, use_fi, use_as):
		"""Open (creating if needed) fname, truncate it to fsize bytes and
		remember the test parameters."""
		self.fname = fname
		self.fsize = fsize
		self.nops = nops
		self.use_fi = use_fi
		self.use_as = use_as

		jflags = 0
		if use_as:
			jflags = libjio.J_LINGER

		self.jf = libjio.open(fname, libjio.O_RDWR | libjio.O_CREAT,
				0o600, jflags)
		# plain read-only handle, used to check the real file contents
		self.f = open(fname, mode = 'rb')

		self.jf.truncate(fsize)

		if use_as:
			# NOTE(review): presumably (max seconds, max bytes) between
			# syncs - confirm against the libjio documentation
			self.jf.autosync_start(5, 10 * 1024 * 1024)

	def apply(self, trans):
		"""Prepare, apply and verify trans in this process.

		Always returns True; errors propagate as exceptions.
		"""
		trans.prepare()
		trans.apply()
		trans.verify()
		return True

	def apply_fork(self, trans):
		"""Apply trans in a forked child with fault injection enabled.

		Returns True if the child exited with status 0, False if it
		exited with any other status. Raises RuntimeError(status) if the
		child did not exit normally.
		"""
		# do the prep before the fork so we can verify() afterwards
		trans.prepare()

		sys.stdout.flush()
		pid = os.fork()
		if pid == 0:
			# child
			try:
				self.fiu_enable()
				trans.apply()
				self.fiu_disable()
			except (IOError, MemoryError):
				# A failure while applying; fault injection may still
				# be active while we try to recover.
				try:
					self.reopen(trans)
				except (IOError, MemoryError):
					pass
				except:
					# deliberate catch-all: report anything unexpected
					# and leave, disabling the faults only once
					self.fiu_disable()
					traceback.print_exc()
					sys.exit(1)
				self.fiu_disable()
				sys.exit(1)
			except:
				# deliberate catch-all: the child must never continue
				# after an unexpected error
				self.fiu_disable()
				traceback.print_exc()
				sys.exit(1)
			trans.verify()
			sys.exit(0)
		else:
			# parent
			_, status = os.waitpid(pid, 0)
			if not os.WIFEXITED(status):
				raise RuntimeError(status)

			return os.WEXITSTATUS(status) == 0

	def reopen(self, trans):
		"""Check the file with jfsck, verify the writes of trans and open
		a fresh libjio handle. Returns the jfsck result."""
		# drop our reference to the current handle before checking
		self.jf = None
		r = jfsck(self.fname)

		trans.verify(write_only = True)

		self.jf = libjio.open(self.fname,
			libjio.O_RDWR | libjio.O_CREAT, 0o600)
		return r

	def fiu_enable(self):
		"""Enable the random failure points; does nothing without use_fi."""
		if not self.use_fi:
			return

		# To improve code coverage, we randomize the probability each
		# time we enable failure points
		fiu.enable_random('jio/*',
				probability = randfloat(0.0005, 0.005))
		fiu.enable_random('linux/*',
				probability = randfloat(0.005, 0.03))
		fiu.enable_random('posix/*',
			probability = randfloat(0.005, 0.03))
		fiu.enable_random('libc/mm/*',
			probability = randfloat(0.003, 0.07))
		fiu.enable_random('libc/str/*',
			probability = randfloat(0.005, 0.07))

	def fiu_disable(self):
		"""Disable every failure point enabled by fiu_enable(); does
		nothing without use_fi."""
		if self.use_fi:
			fiu.disable('libc/mm/*')
			# was missing: fiu_enable() turns this one on too
			fiu.disable('libc/str/*')
			fiu.disable('posix/*')
			fiu.disable('jio/*')
			fiu.disable('linux/*')

	def run(self):
		"""Run nops random transactions, printing a progress dot for each.

		Returns the number of transactions that ended in a simulated
		failure.
		"""
		nfailures = 0
		sys.stdout.write("  ")

		for i in range(1, self.nops + 1):
			sys.stdout.write(".")
			if i % 10 == 0:
				sys.stdout.write(" ")
			if i % 50 == 0:
				sys.stdout.write(" %d\n" % i)
				sys.stdout.write("  ")
			sys.stdout.flush()

			trans = random.choice(t_list)(self.f, self.jf,
					self.fsize)

			if self.use_fi:
				r = self.apply_fork(trans)
			else:
				r = self.apply(trans)
			if not r:
				nfailures += 1
				self.reopen(trans)
				trans.verify(write_only = True)

		sys.stdout.write("\n")
		sys.stdout.flush()
		return nfailures


#
# Main
#

def usage():
	"""Print the command line help to standard output."""
	help_text = """
Use: jiostress <file name> <file size in Mb> [<number of operations>]
	[--fi] [--as]

If the number of operations is not provided, the default (500) will be
used.

If the "--fi" option is passed, the test will perform fault injection. This
option conflicts with "--as".

If the "--as" option is passed, lingering transactions will be used, along
with the automatic syncing thread. This option conflicts with "--fi".
"""
	print(help_text)


def main():
	"""Parse the command line, run the stress test and print a summary.

	Exits with status 1 (after printing the usage text) when the file
	name or size is missing, when the size or the number of operations
	is not a valid integer, or when the size is not positive. Also exits
	with status 1 when both --fi and --as are given.
	"""
	try:
		fname = sys.argv[1]
		fsize = int(sys.argv[2]) * 1024 * 1024
		if fsize <= 0:
			# an empty file leaves no room to pick ranges from
			raise ValueError("file size must be positive")

		nops = 500
		if len(sys.argv) >= 4 and sys.argv[3].isnumeric():
			nops = int(sys.argv[3])
	except (IndexError, ValueError):
		# missing or malformed arguments
		usage()
		sys.exit(1)

	use_fi = '--fi' in sys.argv
	use_as = '--as' in sys.argv

	if use_fi and use_as:
		print("Error: --fi and --as cannot be used together")
		sys.exit(1)

	s = Stresser(fname, fsize, nops, use_fi, use_as)
	print("Running stress test")
	nfailures = s.run()
	# drop the stresser before the final check
	del s
	print("Stress test completed")
	print("  %d operations" % nops)
	print("  %d simulated failures" % nfailures)

	jfsck(fname)
	print("Final check completed")
	#os.unlink(fname)


if __name__ == '__main__':
	main()

