summaryrefslogtreecommitdiff
path: root/network/client.py
blob: 5f454d73b016997b3f3ce17498376a63a5a1b266 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from threading import Lock as Mutex

import dbus
from dbus.mainloop.glib import DBusGMainLoop
from gi.repository import GLib
import logbook
import zmq

from network import avahi
from network.browser import ConsoleServerBrowser

# Module-wide logger, named after this module.
log = logbook.Logger(__name__)

# Connection to the D-Bus system bus, created at import time and attached to
# a GLib main-loop integration object.
dbus_loop = DBusGMainLoop()
system_bus = dbus.SystemBus(mainloop=dbus_loop)

# Proxy for the Avahi server object. Discoverer uses it to create service
# browsers (ServiceBrowserNew) and to resolve services (ResolveService).
avahi_server = dbus.Interface(
    system_bus.get_object(avahi.DBUS_NAME, avahi.DBUS_PATH_SERVER),
    avahi.DBUS_INTERFACE_SERVER,
)

def acquire_or_die(mutex, error):
    """
    Take ``mutex`` without blocking, or fail immediately.

    :param mutex: A lock object; ``mutex.acquire(False)`` must return a
        truthy value when the lock was obtained.
    :param error: The message of the ``RuntimeError`` raised when the lock
        is already held.
    :raises RuntimeError: If the lock could not be acquired right away.
    """
    obtained = mutex.acquire(False)
    if obtained:
        return
    raise RuntimeError(error)

class Discoverer(object):
    """
    Keeps track of the services of one kind that Avahi announces.

    ``self.services`` maps a ``(interface, protocol, name, kind, domain)``
    tuple to ``None`` while the service is still being resolved, and to a
    dict describing the service once it has been resolved.

    Subclasses can override ``on_services_changed()`` to be told whenever
    that mapping changes.

    :param kind: The machine-readable service kind; it is turned into the
        service type ``_<kind>._tcp``.
    """

    # there seems to be a bug in my current zmq version
    # otherwise, just set this to False.
    IPv4_ONLY = True

    def __init__(self, kind):
        super(Discoverer, self).__init__()
        self.kind = '_%s._tcp' % kind
        self.services = {}

        if self.IPv4_ONLY:
            proto = avahi.PROTO_INET
        else:
            proto = avahi.PROTO_UNSPEC

        # Ask the Avahi server for a new service browser, then wrap the
        # returned object path in a proxy so we can subscribe to its signals.
        browser_path = avahi_server.ServiceBrowserNew(
            avahi.IF_UNSPEC, proto,
            self.kind, '',
            dbus.UInt32(0),
        )
        browser_object = system_bus.get_object(avahi.DBUS_NAME, browser_path)
        self.avahi_browser = dbus.Interface(
            browser_object,
            avahi.DBUS_INTERFACE_SERVICE_BROWSER,
        )

        subscriptions = (
            ('ItemNew', self.on_service_discovered),
            ('ItemRemove', self.on_service_removed),
        )
        for signal_name, handler in subscriptions:
            self.avahi_browser.connect_to_signal(signal_name, handler)

    def on_services_changed(self):
        """Hook called after ``self.services`` changed; does nothing here."""

    def on_service_resolved(
        self,
        interface, protocol,
        name, kind, domain,
        host, aprotocol, address, port,
        text, flags,
    ):
        """
        Reply handler for ``ResolveService``: store the resolved service
        under its key and notify via ``on_services_changed()``.
        """
        if aprotocol == avahi.PROTO_INET6:
            # IPv6 addresses are wrapped in brackets inside the endpoint.
            endpoint_host = '[' + address + ']'
        else:
            endpoint_host = address
        target = 'tcp://%s:%d' % (endpoint_host, port)

        record = {
            'name': name,
            'domain': domain,
            'host': host,
            'address': address,
            'port': port,
            'text': text,
            'target': target,
        }
        self.services[(interface, protocol, name, kind, domain)] = record

        self.on_services_changed()

    def on_service_error(self, key, error):
        """
        Error handler for ``ResolveService``: log the failure, forget the
        service stored under ``key`` (if any) and notify.
        """
        log.warning('Error resolving {}: {}', key, error)
        if key in self.services:
            del self.services[key]
        self.on_services_changed()

    def on_service_discovered(
        self,
        interface, protocol,
        name, kind, domain,
        flags,
    ):
        """
        Handler for the browser's ``ItemNew`` signal: register the service
        as pending and start resolving it asynchronously.
        """
        key = (interface, protocol, name, kind, domain)
        log.debug('Discovered service {}', key)

        if key not in self.services:
            # None marks the service as known but not resolved yet.
            self.services[key] = None

            def resolve_failed(error):
                self.on_service_error(key, error)

            avahi_server.ResolveService(
                interface, protocol, name, kind, domain,
                avahi.PROTO_UNSPEC, dbus.UInt32(0),
                reply_handler=self.on_service_resolved,
                error_handler=resolve_failed,
            )
        else:
            log.warning('Discovered a service twice: {}', key)

    def on_service_removed(
        self,
        interface, protocol,
        name, kind, domain,
        flags,
    ):
        """
        Handler for the browser's ``ItemRemove`` signal: drop the service
        (if it is known) and notify.
        """
        self.services.pop((interface, protocol, name, kind, domain), None)
        self.on_services_changed()

class Client(Discoverer):
    """
    This is the counterpart to the Server. This client is capable of
    discovering servers on the LAN and connecting to them (or to any other
    server) using the same request/reply protocol (ZeroMQ).

    :param kind: The machine-readable kind of server you want to connect to.
        Currently only used by ``find_server()``.
    """

    def __init__(self, kind):
        super(Client, self).__init__(kind)
        # Held for the duration of a blocking send(), and from send_async()
        # until the matching recv_async() has returned.
        self._send_mutex = Mutex()

        self.socket = None
        # Set by find_server(). Until then there is no user interface to
        # update, so on_services_changed() must tolerate None.
        self._browser = None

    def connect(self, target, ctx=None):
        """
        Create a REQ socket, connect it to ``target`` and use it for all
        subsequent ``send()`` / ``send_async()`` / ``recv_async()`` calls.

        NOTE(review): every call replaces ``self.socket`` with a new socket;
        a previously created socket is not closed here, and requests are
        not spread over the targets of earlier calls.

        :param target: A ZeroMQ endpoint (e.g., "tcp://192.168.0.1:12345")
        :param ctx: A ZeroMQ context (optional). Defaults to the global
            ``zmq.Context.instance()``.
        """
        if ctx is None:
            ctx = zmq.Context.instance()

        self.socket = ctx.socket(zmq.REQ)
        self.socket.connect(target)

    def send(self, data):
        """
        Send a message to the server, wait for the response, and return it.
        Both the sent and received messages are binary strings, not unicode.

        :param data: The data to be sent to the server.
        :rtype: string (the data returned from the server)
        :raises RuntimeError: If a ``send_async()`` is still waiting for its
            ``recv_async()``.
        """
        acquire_or_die(
            self._send_mutex,
            "You called send_async(), then send(): "
            "you need to call recv_async() first.",
        )
        try:
            self.socket.send(data)
            return self.socket.recv()
        finally:
            self._send_mutex.release()

    def send_async(self, data):
        """
        Send a message to the server, without waiting for a response.

        Currently this method waits for acknowledgement of the message by a
        server, but this extra latency will be removed in the future.

        In order to get the result of the message, call ``.recv_async()``.

        :param data: The data to be sent to the server.
        :raises RuntimeError: If a previous ``send_async()`` has not been
            completed with ``recv_async()`` yet.
        """
        acquire_or_die(
            self._send_mutex,
            "You tried to call send_async() twice. "
            "You need to call recv_async() first.",
        )
        try:
            self.socket.send(data)
        except BaseException:
            # Nothing was sent, so no reply is pending: give the mutex back,
            # otherwise the client could never send again.
            self._send_mutex.release()
            raise

    def recv_async(self):
        """
        After sending a message to the server with ``.send_async()``, calling
        this method will wait for the response from the server and return it.

        Once the response has been received, the client is free to send
        again. If receiving fails, the request stays pending so that
        ``recv_async()`` can be retried.

        :rtype: string (the data returned from the server)
        """
        reply = self.socket.recv()
        # The request/reply round trip is complete: release the mutex taken
        # by send_async(). The locked() guard avoids an error from releasing
        # a lock that is not held.
        if self._send_mutex.locked():
            self._send_mutex.release()
        return reply

    def on_services_changed(self):
        """
        Push the list of resolved services to the browser, if one is active.

        Services that are still being resolved (stored as ``None``) are
        skipped.
        """
        super(Client, self).on_services_changed()
        if self._browser is None:
            # find_server() has not been called yet; nothing to update.
            return
        services = [
            service
            for service in self.services.values()
            if service is not None
        ]
        self._browser.update(services)

    def find_server(self, browser_cls=ConsoleServerBrowser, connect=True):
        """
        Find a server on the local network (using avahi) and (optionally)
        connect to it.

        :param browser_cls: A Browser class to provide the user interface.
            It is instantiated with this client and its ``run()`` method is
            given the ``run`` callable of a new GLib main loop.
        :param connect: If true, the client will immediately connect to the
            chosen server.
        :rtype: (string) The ZeroMQ endpoint of the chosen server, or
            ``None`` if the browser returned none.
        """
        mainloop = GLib.MainLoop()

        self._browser = browser_cls(self)
        endpoint = self._browser.run(mainloop.run)

        if connect and endpoint is not None:
            self.connect(endpoint)

        return endpoint