326 lines
9.7 KiB
Python
326 lines
9.7 KiB
Python
'''Firefox Marionette protocol for asyncio

This module provides an asynchronous implementation of the Firefox
Marionette protocol over TCP sockets.

Example (run inside a coroutine):

>>> async with Marionette() as mn:
...     await mn.connect()
...     await mn.navigate('https://getfirefox.com/')
'''
|
|
|
|
import dataclasses
|
|
import abc
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import random
|
|
import socket
|
|
from types import TracebackType
|
|
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
|
|
|
|
|
log = logging.getLogger(__name__)

# Kinds of window accepted by ``WebDriver:NewWindow``.  ``Literal[a, b]`` is
# the PEP 586 shorthand for ``Union[Literal[a], Literal[b]]``.
WindowType = Literal['window', 'tab']
|
|
|
|
|
|
@dataclasses.dataclass
class WindowRect:
    '''Position and dimensions of a browser window, all in pixels.'''

    #: Horizontal position (x coordinate) in pixels
    x: int
    #: Vertical position (y coordinate) in pixels
    y: int
    #: Window height in pixels
    height: int
    #: Window width in pixels
    width: int
|
|
|
|
|
|
class MarionetteException(Exception):
    '''Base class for errors raised by the Marionette client.

    When the server answers a command with an error, the exception is
    created with the server-supplied error value as its only argument.
    '''
|
|
|
|
|
|
class _BaseRPC(metaclass=abc.ABCMeta):
|
|
# pylint: disable=too-few-public-methods
|
|
|
|
@abc.abstractmethod
|
|
async def _send_message(self, command: str, **kwargs: Any) -> Any:
|
|
...
|
|
|
|
|
|
class WebDriverBase(_BaseRPC, metaclass=abc.ABCMeta):
    '''WebDriver protocol implementation

    Every method sends a single ``WebDriver:*`` command through
    :py:meth:`_send_message`, which concrete subclasses implement.
    '''

    async def close_window(self) -> List[str]:
        '''Close the current window

        If the last window is closed, the session will be ended.

        :returns: List of handles of remaining windows
        '''
        windows: List[str] = await self._send_message('WebDriver:CloseWindow')
        return windows

    async def fullscreen(self) -> WindowRect:
        '''Switch the current window to fullscreen mode

        :returns: New :py:class:`WindowRect` with current size/location
        '''
        # The server replies with a plain rect mapping; convert it so the
        # return value matches the annotation, like get_window_rect() does.
        res: Dict[str, int] = await self._send_message(
            'WebDriver:FullscreenWindow'
        )
        return WindowRect(**res)

    async def get_title(self) -> str:
        '''Get the current window title'''
        res = await self._send_message('WebDriver:GetTitle')
        title: str = res['value']
        return title

    async def get_url(self) -> str:
        '''Get the URL of the current window'''
        res = await self._send_message('WebDriver:GetCurrentURL')
        url: str = res['value']
        return url

    async def get_window_rect(self) -> WindowRect:
        '''Get the current window position and dimensions'''
        res: Dict[str, int] = await self._send_message(
            'WebDriver:GetWindowRect'
        )
        return WindowRect(**res)

    async def get_window_handles(self) -> List[str]:
        '''Get a list of handles for all open windows'''
        handles: List[str]
        handles = await self._send_message('WebDriver:GetWindowHandles')
        return handles

    async def navigate(self, url: str) -> None:
        '''Navigate to the specified location'''
        await self._send_message('WebDriver:Navigate', url=url)

    async def new_window(
        self,
        type: Optional[WindowType] = None,  # pylint: disable=redefined-builtin
        focus: bool = False,
        private: bool = False,
    ) -> str:
        '''Open a new window or tab

        :param type: Either ``'window'`` or ``'tab'``; when omitted the
            parameter is not sent and the server's default applies
        :param focus: Give the new window focus
        :param private: Open a new private window
        :returns: Handle of the newly opened window
        '''
        params: Dict[str, Any] = {'focus': focus, 'private': private}
        # Only send ``type`` when given: an explicit JSON null is not the
        # same as an absent parameter and may be rejected by the server.
        if type is not None:
            params['type'] = type
        res = await self._send_message('WebDriver:NewWindow', **params)
        handle: str = res['handle']
        return handle

    async def set_window_rect(
        self,
        x: Optional[int] = None,
        y: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> WindowRect:
        '''Move and/or resize the current window

        At least one complete pair must be given: ``x`` and ``y`` to move
        the window, or ``height`` and ``width`` to resize it.  A lone
        coordinate or dimension on its own is rejected.

        :param x: Position x coordinate in pixels
        :param y: Position y coordinate in pixels
        :param height: Height in pixels
        :param width: Width in pixels
        :raises ValueError: if neither a complete position nor a complete
            size is given
        :returns: New :py:class:`WindowRect` with current size/location
        '''
        has_position = x is not None and y is not None
        has_size = height is not None and width is not None
        if not (has_position or has_size):
            raise ValueError('x and y OR height and width need values')
        res: Dict[str, int] = await self._send_message(
            'WebDriver:SetWindowRect',
            x=x,
            y=y,
            height=height,
            width=width,
        )
        return WindowRect(**res)

    async def refresh(self) -> None:
        '''Refresh the current window content'''
        await self._send_message('WebDriver:Refresh')

    async def switch_to_window(self, handle: str, focus: bool = True) -> None:
        '''Switch to the specified window

        :param handle: Window handle
        :param focus: Give the selected window focus
        '''
        await self._send_message(
            'WebDriver:SwitchToWindow', handle=handle, focus=focus
        )
|
|
|
|
|
|
class Marionette(WebDriverBase):
    '''Firefox Marionette session

    This class implements the WebDriver protocol; see
    :py:class:`WebDriverBase` for available methods.
    '''

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 2828,
        family: int = socket.AF_INET6,
    ) -> None:
        self.host = host
        '''Socket host'''
        self.port = port
        '''Socket port'''
        #: Address family used for the connection (IPv6 by default); pass
        #: ``socket.AF_INET`` or ``socket.AF_UNSPEC`` to reach IPv4 servers
        self.family = family
        self._transport: Optional[asyncio.Transport] = None
        # Futures for requests awaiting a reply, keyed by message ID.  The key
        # -1 holds the future for the server's greeting while _connect() runs.
        self._waiting: Dict[int, asyncio.Future[Any]] = {}
        # Next candidate message ID; see _next_message_id().
        self._msgid = 0
        self.session: Optional[Dict[str, Any]] = None
        '''Marionette session information'''

    async def __aenter__(self) -> 'Marionette':
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    async def close(self) -> None:
        '''Close the connection

        Requests still awaiting a reply fail with :py:exc:`EOFError` instead
        of waiting forever.
        '''
        self._connection_error(EOFError('Connection closed'))
        if self._transport is not None:
            self._transport.close()
            self._transport = None
        self.session = None

    async def connect(self) -> None:
        '''Connect to the Marionette socket and begin a session

        Retries once per second until the server accepts the connection and
        sends its greeting, then issues ``WebDriver:NewSession`` and stores
        the result in :py:attr:`session`.
        '''
        hello = await self._connect()
        log.info(
            'Connected to Marionette server '
            '(protocol: %s, application type: %s)',
            hello.get('marionetteProtocol'),
            hello.get('applicationType'),
        )
        self.session = await self._send_message(
            'WebDriver:NewSession', strictFileInteractibility=True
        )

    async def _connect(self) -> Dict[str, Any]:
        '''Open the socket and return the server's greeting message.

        Connection failures (and connections that drop before the greeting)
        are logged and retried after one second, indefinitely.
        '''
        loop = asyncio.get_running_loop()
        while True:
            log.info(
                'Connecting to Marionette server %s on port %d',
                self.host,
                self.port,
            )
            try:
                transport, _protocol = await loop.create_connection(
                    lambda: _MarionetteProtocol(self),
                    host=self.host,
                    port=self.port,
                    family=self.family,
                )
                # The greeting carries no message ID; -1 marks its waiter.
                fut = self._waiting[-1] = loop.create_future()
                hello: Dict[str, Any] = await fut
            except (OSError, EOFError) as e:
                log.error('Failed to connect to Marionette server: %s', e)
                await asyncio.sleep(1)
            else:
                self._transport = cast(asyncio.Transport, transport)
                return hello

    def _next_message_id(self) -> int:
        '''Allocate a message ID that no pending request is using.

        IDs are handed out sequentially in the range 0-65535, skipping any
        still awaiting a reply, so a new request can never overwrite the
        future of an in-flight one.
        '''
        for _ in range(65536):
            msgid = self._msgid
            self._msgid = (msgid + 1) % 65536
            if msgid not in self._waiting:
                return msgid
        raise MarionetteException('Too many outstanding requests')

    async def _send_message(self, command: str, **kwargs: Any) -> Any:
        '''Send ``command`` with ``kwargs`` and wait for the server's reply.

        :raises MarionetteException: if not connected, or if the server
            answers with an error
        '''
        if self._transport is None:
            raise MarionetteException('Not connected')
        loop = asyncio.get_running_loop()
        msgid = self._next_message_id()
        msg = json.dumps([0, msgid, command, kwargs or None])
        log.debug('Sending message: %r', msg)
        # The length prefix counts bytes, so measure the encoded payload.
        payload = msg.encode('utf-8')
        fut: asyncio.Future[Any] = loop.create_future()
        self._waiting[msgid] = fut
        try:
            self._transport.write(b'%d:%s' % (len(payload), payload))
            return await fut
        finally:
            # If we leave without a reply (write failed or the caller was
            # cancelled), drop our entry -- but only if it is still ours.
            if self._waiting.get(msgid) is fut:
                del self._waiting[msgid]

    def _message_received(self, data: bytes) -> None:
        '''Dispatch one decoded-from-the-wire message to its waiter.'''
        try:
            message = json.loads(data)
        except ValueError as e:
            log.error('Got invalid message from Marionette server: %s', e)
            return
        if -1 in self._waiting:
            # First message after connecting: the server greeting.
            hello_fut = self._waiting.pop(-1)
            if not hello_fut.done():
                hello_fut.set_result(message)
        if not isinstance(message, list):
            return
        if len(message) != 4:
            log.error('Malformed message from Marionette server: %r', message)
            return
        msgtype, msgid, error, result = message
        if msgtype != 1:
            log.error('Unsupported message type: %r', msgtype)
            return
        fut = self._waiting.pop(msgid, None)
        if fut is None:
            log.warning('Nothing waiting for response to message ID %s', msgid)
            return
        if fut.done():
            # The waiter already gave up (e.g. it was cancelled); setting a
            # result now would raise InvalidStateError.
            return
        if error is not None:
            fut.set_exception(MarionetteException(error))
        else:
            fut.set_result(result)

    def _connection_error(self, exc: Exception) -> None:
        '''Fail every pending request with ``exc``.'''
        waiting, self._waiting = self._waiting, {}
        for fut in waiting.values():
            if not fut.done():
                fut.set_exception(exc)
|
|
|
|
|
class _MarionetteProtocol(asyncio.Protocol):
    '''Split the Marionette byte stream into messages.

    Each message on the wire is ``<length>:<payload>``, where ``<length>`` is
    the decimal byte length of the payload.  Complete payloads are handed to
    the owning :py:class:`Marionette`; incomplete ones stay buffered.
    '''

    def __init__(self, marionette: Marionette) -> None:
        self.marionette = marionette
        self.buf = bytearray()
        self.transport: Optional[asyncio.BaseTransport] = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Kept so a corrupt stream can be shut down from data_received().
        self.transport = transport

    def data_received(self, data: bytes) -> None:
        self.buf += data
        while self.buf:
            idx = self.buf.find(b':')
            if idx < 0:
                return  # length prefix not complete yet
            try:
                length = int(self.buf[:idx])
            except ValueError:
                length = -1
            if length < 0:
                # Framing is lost and cannot be resynchronised; give up on
                # this connection (connection_lost() fails pending requests).
                log.error(
                    'Invalid length prefix from Marionette server: %r',
                    bytes(self.buf[:idx]),
                )
                self.buf.clear()
                if self.transport is not None:
                    self.transport.close()
                return
            end = idx + 1 + length
            if len(self.buf) < end:
                return  # payload not fully received yet
            payload = bytes(self.buf[idx + 1 : end])
            # Delete in place rather than re-slicing, which would copy the
            # whole remaining buffer for every message.
            del self.buf[:end]
            log.debug('Received message: %s', payload)
            # pylint: disable=protected-access
            self.marionette._message_received(payload)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        # A clean close (exc is None) is routine, so don't report it as an
        # error; pending requests are failed either way.
        if exc is None:
            log.info('Connection lost, cancelling outstanding requests')
            exc = EOFError()
        else:
            log.error('Connection lost, cancelling outstanding requests')
        self.transport = None
        # pylint: disable=protected-access
        self.marionette._connection_error(exc)
|