#!/usr/bin/env python
# @@@ START COPYRIGHT @@@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# @@@ END COPYRIGHT @@@

from subprocess import call
import sys, os, subprocess, tempfile

# Command-line arguments of this script. This is the same list object as
# sys.argv; parseArgs() removes the '-n <node ID>' option from it in place,
# and startDTMCI() forwards the remaining arguments to dtmci.
argList = sys.argv
# Node selected with '-n'. The default 0 means "none requested": main() then
# picks a node through getNextNode(). parseArgs() stores the '-n' value here
# as a string.
execNode = 0

def isDev():
	"""Tell whether sqconfig lacks a 'node-name=' entry matching the grep pattern.

	Runs grep on $TRAF_HOME/sql/scripts/sqconfig looking for the pattern
	'node-name=.[^A-Za-z].[0-9]*'.

	Returns:
		True when grep produces no output (no matching line, or grep could
		not read the file), False when at least one line matched.

	Exits the script with status 1 when $TRAF_HOME is not set.
	"""
	try:
		mySQROOT = os.environ['TRAF_HOME']
	except KeyError:
		# Only a missing variable is expected here; anything else should surface.
		print("ERROR: $TRAF_HOME variable not set.  Please setup SeaQuest environment.")
		# sys.exit() rather than the site-provided exit(), which is meant for
		# interactive use and is absent when the site module is not loaded.
		sys.exit(1)

	sqconfig = "%s/sql/scripts/sqconfig" % mySQROOT
	grepCommand = ['grep', 'node-name=.[^A-Za-z].[0-9]*', sqconfig]

	proc = subprocess.Popen(grepCommand, stdout = subprocess.PIPE, shell=False)
	grepOut, _ = proc.communicate()

	# Truthiness instead of == '' so the test also holds when the captured
	# output is a bytes object.
	return not grepOut

def getNextNode():
	"""Return the ID of the first node that 'sqshell -c node info' lists as 'Up'.

	Parsing rules, applied to the command's standard output:
	  * lines without a space are ignored;
	  * everything up to and including the line containing
	    '--- ----------- --------' is skipped;
	  * after that, lines are consumed in pairs. The text following the
	    first space of each line of a pair is concatenated into one record;
	  * in each record the first whitespace-separated field is the node ID
	    and the third one is the node state.

	Returns:
		int: ID of the first record whose state is 'Up'.

	Exits the script with status 1 when no such record exists.
	"""
	nodeInfo = ['sqshell','-c','node','info']

	# stdout is collected in a temporary file so that wait() can never block
	# on a full pipe. stderr is discarded: it used to be an unread PIPE,
	# which stalls the child (and therefore wait()) once the pipe fills up.
	outFile = tempfile.TemporaryFile()
	errSink = open(os.devnull, 'w')
	try:
		proc = subprocess.Popen(nodeInfo,
					stdout=outFile,
					stderr=errSink,
					shell=False,
					close_fds=True)
		proc.wait()
		outFile.seek(0)
		cmdOut = outFile.read()
	finally:
		# Release both handles even if sqshell could not be started.
		errSink.close()
		outFile.close()

	# The temporary file is binary; normalise to text when read() gave bytes.
	if not isinstance(cmdOut, str):
		cmdOut = cmdOut.decode('utf-8', 'replace')

	nodeList = []
	process = False
	nodeLine = ''
	for line in cmdOut.split('\n'):
		if " " not in line:
			continue
		if "--- ----------- --------" in line:
			process = True
			continue
		if not process:
			continue
		value = line.split(" ", 1)[1]
		if nodeLine == '':
			# First line of a pair: remember it until its partner arrives.
			nodeLine = value
		else:
			nodeList.append(nodeLine + value)
			nodeLine = ''

	for record in nodeList:
		nodeValue = record.split()
		# Guard against short records instead of failing with IndexError.
		if len(nodeValue) > 2 and nodeValue[2] == 'Up':
			return int(nodeValue[0])

	print("Error: Unable to find a node in 'Up' state. Exiting")
	sys.exit(1)

def parseArgs():
	"""Consume the script's own options from the global argList.

	'-n <node ID>': stores the node ID in the global execNode (kept as the
	string that was typed) and removes both tokens from argList in place,
	so that only dtmci arguments remain.
	'-h': prints the usage line and exits with status 0.

	Exits with status 1 when '-n' is the last argument or when its value is
	not made of digits only.
	"""
	global argList
	global execNode
	if '-n' in argList:
		i = argList.index('-n')
		if i >= len(argList)-1:
			print("Incorrect usage: run_dtmci -n <node ID>")
			sys.exit(1)
		nodeArg = argList[i+1]
		# The node ID is later interpolated into a shell command line by
		# startDTMCI(), so refuse anything that is not a plain number.
		if not nodeArg.isdigit():
			print("Incorrect usage: run_dtmci -n <node ID>")
			sys.exit(1)
		execNode = nodeArg
		del argList[i:i+2]
	if '-h' in argList:
		print("Usage: run_dtmci [-n <node ID>] [dtmci command arguments]")
		sys.exit(0)


def startDTMCI(nodeID):
	"""Run dtmci with the arguments left in the global argList.

	Args:
		nodeID: -1 or 0 runs dtmci directly as a child process. Any other
			value is placed in an 'exec {nid ...}' request that is piped
			to 'sqshell -a', with the current tty (from getTTY()) as the
			request's 'in' and 'out' and as dtmci's first argument.

	Exits the script with status 1 when launching through sqshell raises.
	"""
	# Work on a copy without the script name. Aliasing argList and popping
	# from it would also strip the global list (and sys.argv) as a side effect.
	myList = argList[1:]
	callCmd = ['dtmci'] + myList

	if nodeID == -1 or nodeID == 0:
		call(callCmd)
	else:
		myTTY = getTTY()
		executeCmd = "echo exec {nid %s, in %s, out %s} dtmci %s" \
				% (nodeID, myTTY, myTTY, myTTY)
		if len(myList) > 0:
			executeCmd = executeCmd + ' ' + ' '.join(myList)
		# NOTE(review): the dtmci arguments are joined unquoted into a shell
		# command line, so shell metacharacters in them are interpreted by
		# the shell. Quote them if arguments can ever come from another source
		# than the invoking user.
		try:
			call(executeCmd + " | sqshell -a", shell=True)
		except (Exception, KeyboardInterrupt):
			# Includes Ctrl-C during the session: leave quietly with a
			# failure status instead of printing a traceback.
			sys.exit(1)
def getTTY():
	"""Return what the 'tty' command prints, without trailing whitespace."""
	ttyProc = subprocess.Popen('tty', stdout = subprocess.PIPE)
	# communicate() yields (stdout, stderr); only stdout is captured here.
	ttyName = ttyProc.communicate()[0]
	return ttyName.rstrip()

def main():
	"""Parse the command line, choose a node and start dtmci.

	When isDev() is false, dtmci is started with node -1. Otherwise the node
	given with '-n' is used, or, when none was given (execNode still 0), the
	node returned by getNextNode().
	"""
	parseArgs()
	if not isDev():
		startDTMCI(-1)
		return
	targetNode = execNode if execNode != 0 else getNextNode()
	startDTMCI(targetNode)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
	main()
