#!/usr/bin/python
"""Exchange SPAKE2 keys and print out the session key.

Supports the symmetric variant ('S') and either role of the asymmetric
variant ('A' or 'B'), selected with --side. Uses the default SPAKE2
parameters.
"""
import argparse
from binascii import hexlify, unhexlify
import attr
import sys

from spake2 import SPAKE2_A, SPAKE2_B, SPAKE2_Symmetric


def _text_arg(value):
    """Return a command-line argument value as text.

    Byte strings (what Python 2 places in ``sys.argv``) are decoded as
    UTF-8 so that non-ASCII passwords and identifiers are accepted; values
    that are already text are returned unchanged. A byte string that is not
    valid UTF-8 raises ``UnicodeDecodeError``, which argparse reports as an
    invalid argument value.
    """
    if isinstance(value, bytes):
        return value.decode('utf8')
    return value


def main():
    """Parse the command line and run one SPAKE2 exchange over stdin/stdout.

    ``--code``, ``--side-id`` and ``--side`` are mandatory. ``--other-side-id``
    is additionally required when ``--side`` is ``A`` or ``B``. Invalid or
    missing arguments make argparse print a usage message and exit with
    status 2 before any protocol message is written.
    """
    parser = argparse.ArgumentParser(prog='version_exchange')
    parser.add_argument(
        '--code', dest='code', type=_text_arg, required=True,
        help='Password to use to connect to other side')
    parser.add_argument(
        '--side-id', dest='side_id', type=_text_arg, required=True,
        help='Identifier for this side of the exchange')
    parser.add_argument(
        '--side', dest='side', choices=['A', 'B', 'S'], required=True,
        help=('Which side this represents. '
              'Decides whether we use symmetric or asymmetric variant.'))
    parser.add_argument(
        '--other-side-id', dest='other_side_id', type=_text_arg,
        help=('Identifier for other side of the exchange. '
              'Only necessary for asymmetric variants.'))
    params = parser.parse_args(sys.argv[1:])

    # The asymmetric variants need both identities; fail with a usage error
    # here rather than with an AttributeError deep inside get_protocol().
    if params.side != 'S' and params.other_side_id is None:
        parser.error('--other-side-id is required when --side is A or B')

    transport = Transport(input_stream=sys.stdin, output_stream=sys.stdout)
    protocol = get_protocol(
        params.code, params.side, params.side_id, params.other_side_id)
    run_exchange(transport, protocol)


def get_protocol(code, side, side_id, other_side_id):
    """Build the SPAKE2 protocol object for the requested side.

    Args:
        code: Shared password as text; it is UTF-8 encoded before use.
        side: ``'S'`` for the symmetric variant, ``'A'`` or ``'B'`` for the
            two roles of the asymmetric variant.
        side_id: Text identifier for this side; UTF-8 encoded before use.
        other_side_id: Text identifier for the peer; UTF-8 encoded before
            use. Ignored (and may be ``None``) when ``side`` is ``'S'``.

    Returns:
        A ``SPAKE2_Symmetric``, ``SPAKE2_A`` or ``SPAKE2_B`` instance
        constructed with the encoded password and identifiers.

    Raises:
        AssertionError: If ``side`` is not one of ``'A'``, ``'B'``, ``'S'``.
        ValueError: If ``code`` or ``side_id`` is ``None``, or if
            ``other_side_id`` is ``None`` for an asymmetric side.
    """
    # Reject an unknown side before touching the other arguments so the
    # caller sees the real problem rather than an unrelated encoding error.
    if side not in ('A', 'B', 'S'):
        raise AssertionError('Invalid side: %r' % (side,))
    if code is None:
        raise ValueError('code is required')
    if side_id is None:
        raise ValueError('side_id is required')

    code = code.encode('utf8')
    side_id = side_id.encode('utf8')
    if side == 'S':
        return SPAKE2_Symmetric(code, idSymmetric=side_id)

    if other_side_id is None:
        raise ValueError(
            'other_side_id is required for asymmetric side %r' % (side,))
    other_side_id = other_side_id.encode('utf8')
    if side == 'A':
        return SPAKE2_A(code, idA=side_id, idB=other_side_id)
    # Only 'B' remains: our identity is idB and the peer's is idA.
    return SPAKE2_B(code, idA=other_side_id, idB=side_id)


def run_exchange(transport, protocol):
    """Perform one SPAKE2 round trip and emit the resulting key.

    The message returned by ``protocol.start()`` is hex-encoded and sent as
    a line on ``transport``. One line is then read back, hex-decoded and
    passed to ``protocol.finish()``; the value it returns is hex-encoded and
    sent as a second line.

    Args:
        transport: Object providing ``send_line(line)`` and
            ``receive_line()``.
        protocol: Object providing ``start()`` and ``finish(inbound)``.
    """
    # Our half of the handshake goes out first.
    transport.send_line(hexlify(protocol.start()))

    # Then read the peer's half and derive the key from it.
    peer_message = unhexlify(transport.receive_line())
    session_key = protocol.finish(peer_message)

    # The derived key is written out hex-encoded, like the handshake message.
    transport.send_line(hexlify(session_key))


@attr.s
class Transport(object):
    """Line-oriented channel built on a pair of file-like streams.

    ``input_stream`` must provide ``readline()``; ``output_stream`` must
    provide ``write()`` and ``flush()``.
    """

    input_stream = attr.ib()
    output_stream = attr.ib()

    def send_line(self, line):
        """Write ``line`` as a single newline-terminated line and flush.

        Trailing whitespace is removed from ``line`` and the result is
        UTF-8 encoded before it is written.
        """
        out = self.output_stream
        payload = line.rstrip().encode('utf8')
        out.write(payload)
        out.write('\n')
        # Flush right away; presumably so a peer waiting on the other end of
        # a pipe sees the line immediately -- confirm against how this is run.
        out.flush()

    def receive_line(self):
        """Read one line, strip surrounding whitespace, and decode as UTF-8."""
        raw_line = self.input_stream.readline()
        return raw_line.strip().decode('utf8')


# Run the exchange only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()