451 lines
13 KiB
Python
451 lines
13 KiB
Python
import asyncio
|
|
import base64
|
|
import datetime
|
|
import importlib.metadata
|
|
import logging
|
|
import os
|
|
import re
|
|
import tempfile
|
|
from pathlib import Path
|
|
from types import TracebackType
|
|
from typing import Annotated, Optional, Self, Type
|
|
|
|
import fastapi
|
|
import httpx
|
|
import pydantic
|
|
import pyrfc6266
|
|
|
|
|
|
# Only the ASGI application object is part of this module's public API.
__all__ = ['app']

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Metadata of the installed distribution looked up under this module's name;
# its 'Name' and 'Version' fields are read elsewhere in this file.
DIST = importlib.metadata.metadata(__name__)
|
|
|
|
# Matches every character that is neither a lowercase ASCII letter nor a
# space; substituting it away leaves only letters and spaces.
DESCRIPTION_CLEAN_PATTERN = re.compile(r'[^a-z ]')

# Words removed from a cleaned transaction description before it is used
# as a search string.
EXCLUDE_DESCRIPTION_WORDS = {
    'a', 'ach', 'an',
    'card', 'debit',
    'pay', 'payment', 'purchase',
    'retail', 'the',
}
|
|
|
|
# Runtime settings, each read once from the environment at import time.

# Only the exact value '1' enables this flag.
ALLOW_RESIGNING_CERTS = os.getenv('ALLOW_RESIGNING_CERTS') == '1'

# Base URLs of the two backing services (no trailing slash in the defaults).
FIREFLY_URL = os.getenv('FIREFLY_URL', 'http://firefly-iii')
PAPERLESS_URL = os.getenv('PAPERLESS_URL', 'http://paperless-ngx')

# Size limits in bytes; a non-integer value raises ValueError at import.
MAX_KEY_FILE_SIZE = int(os.getenv('MAX_KEY_FILE_SIZE', 8192))
MAX_DOCUMENT_SIZE = int(os.getenv('MAX_DOCUMENT_SIZE', 50 * 2**20))
|
|
|
|
|
|
class SignError(Exception):
    """Raised when an SSH key cannot be signed."""
|
|
|
|
|
|
class FireflyIIITransactionSplit(pydantic.BaseModel):
    """One split (journal entry) of a Firefly III transaction."""

    type: str
    date: datetime.datetime
    # Kept as the string Firefly III sends; callers convert it themselves.
    amount: str
    # Identifier used when attaching documents to this split.
    transaction_journal_id: int
    description: str
|
|
|
|
|
|
class FireflyIIITransaction(pydantic.BaseModel):
    """A Firefly III transaction: a list of one or more splits."""

    transactions: list[FireflyIIITransactionSplit]
|
|
|
|
|
|
class FireflyIIIWebhook(pydantic.BaseModel):
    """Body of a Firefly III webhook request; only ``content`` is modelled."""

    content: FireflyIIITransaction
|
|
|
|
|
|
class PaperlessNgxDocument(pydantic.BaseModel):
    """A single document entry in a Paperless-ngx search result."""

    # Document ID; used to build the per-document download URL.
    id: int
    title: str
|
|
|
|
|
|
class PaperlessNgxSearchResults(pydantic.BaseModel):
    """One page of results from the Paperless-ngx document search."""

    # Number of documents the search reported.
    count: int
    # Links to the neighbouring pages; a non-empty ``next`` means more
    # results exist beyond this page.
    next: str | None
    previous: str | None
    # Documents contained in this page.
    results: list[PaperlessNgxDocument]
|
|
|
|
|
|
class SSHKeySignResponse(pydantic.BaseModel):
    """Response body of the SSH key signing endpoint."""

    # True when no error was recorded while handling the request.
    success: bool
    # Human-readable error messages, or None when there were none.
    errors: Optional[list[str]] = None
    # Maps certificate file name to certificate text, or None when nothing
    # was signed.  Given an explicit default so it behaves like ``errors``
    # (without one, pydantic v2 treats the field as required).
    certificates: Optional[dict[str, str]] = None
|
|
|
|
|
|
class HttpxClientMixin:
    """Mixin owning a lazily created ``httpx.AsyncClient``.

    Instances are async context managers: entering and leaving the
    ``async with`` block enters and leaves the underlying client.
    Subclasses customise the client by overriding ``_get_client``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Stays None until the ``client`` property is first read.
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Return the shared client, creating it on first access."""
        existing = self._client
        if existing is not None:
            return existing
        self._client = self._get_client()
        return self._client

    def _get_client(self) -> httpx.AsyncClient:
        """Create a client whose User-Agent is ``<Name>/<Version>``."""
        user_agent = f'{DIST["Name"]}/{DIST["Version"]}'
        return httpx.AsyncClient(headers={'User-Agent': user_agent})

    async def __aenter__(self) -> Self:
        await self.client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[Exception]],
        exc_value: Optional[Exception],
        tb: Optional[TracebackType],
    ) -> None:
        # The client's own return value is discarded, so exceptions raised
        # inside the ``async with`` body are never suppressed here.
        await self.client.__aexit__(exc_type, exc_value, tb)
|
|
|
|
|
|
class Firefly(HttpxClientMixin):
    """Client for the Firefly III API rooted at ``FIREFLY_URL``."""

    def _get_client(self) -> httpx.AsyncClient:
        """Build the HTTP client, adding bearer authentication if configured.

        ``FIREFLY_AUTH_TOKEN`` names a *file* holding the token.  If that
        file cannot be opened the error is logged and the client is returned
        without an ``Authorization`` header.
        """
        client = super()._get_client()
        token_path = os.environ.get('FIREFLY_AUTH_TOKEN')
        if not token_path:
            return client
        try:
            handle = open(token_path, encoding='utf-8')
        except OSError as e:
            log.error('Could not load Firefly-III access token: %s', e)
            return client
        with handle:
            token = handle.read().strip()
        client.headers['Authorization'] = f'Bearer {token}'
        return client

    async def attach_receipt(
        self,
        xact_id: int,
        doc: bytes,
        filename: str,
        title: str | None = None,
    ) -> None:
        """Attach a document to a transaction journal.

        Two requests are made: the first creates the attachment record, the
        second uploads ``doc`` as that attachment's content.

        Args:
            xact_id: ID of the transaction journal to attach to.
            doc: Raw bytes of the document.
            filename: File name recorded for the attachment.
            title: Optional title; omitted from the request when empty.

        Raises:
            httpx.HTTPStatusError: If either request returns an error status.
        """
        log.info('Attaching receipt %r to transaction %d', filename, xact_id)

        # Step 1: create the attachment record.
        form = {
            'filename': filename,
            'attachable_type': 'TransactionJournal',
            'attachable_id': xact_id,
        }
        if title:
            form['title'] = title
        created = await self.client.post(
            f'{FIREFLY_URL}/api/v1/attachments', data=form
        )
        created.raise_for_status()
        attachment = created.json()['data']

        # Step 2: upload the document bytes to the new attachment.
        upload_url = (
            f'{FIREFLY_URL}/api/v1/attachments/{attachment["id"]}/upload'
        )
        uploaded = await self.client.post(
            upload_url,
            content=doc,
            headers={'Content-Type': 'application/octet-stream'},
        )
        uploaded.raise_for_status()
|
|
|
|
|
|
class Paperless(HttpxClientMixin):
    """Client for the Paperless-ngx API rooted at ``PAPERLESS_URL``."""

    def _get_client(self) -> httpx.AsyncClient:
        """Build the HTTP client, adding token authentication if configured.

        ``PAPERLESS_AUTH_TOKEN`` names a *file* holding the token.  If that
        file cannot be opened the error is logged and the client is returned
        without an ``Authorization`` header.
        """
        client = super()._get_client()
        if token_file := os.environ.get('PAPERLESS_AUTH_TOKEN'):
            try:
                f = open(token_file, encoding='utf-8')
            except OSError as e:
                log.error(
                    'Could not load Paperless-ngx authentication token: %s', e
                )
            else:
                with f:
                    token = f.read().strip()
                client.headers['Authorization'] = f'Token {token}'
        return client

    async def find_receipts(
        self, search: str, amount: float, date: datetime.date
    ) -> list[tuple[str, str, bytes]]:
        """Search for receipts matching a transaction and download them.

        The query combines the search terms, the amount, the document type
        ``Invoice/Receipt`` and a creation date within two days either side
        of ``date``.  Only the first page of results is used.

        Args:
            search: Free-text search terms.
            amount: Transaction amount, included in the query as ``str(amount)``.
            date: Transaction date around which documents are searched.

        Returns:
            A list of ``(filename, title, content)`` tuples.  The list is
            empty when the search fails; documents that cannot be downloaded
            or exceed ``MAX_DOCUMENT_SIZE`` are skipped.
        """
        date_begin = date - datetime.timedelta(days=2)
        date_end = date + datetime.timedelta(days=2)
        query = ' '.join(
            (
                search,
                str(amount),
                'type:Invoice/Receipt',
                f'created:[{date_begin} TO {date_end}]',
            )
        )
        log.info('Searching for receipt in Paperless: %s', query)
        docs: list[tuple[str, str, bytes]] = []
        url = f'{PAPERLESS_URL}/api/documents/'
        r = await self.client.get(url, params={'query': query})
        if r.status_code != 200:
            # Building the detail text is only worth it if it will be logged.
            if log.isEnabledFor(logging.ERROR):
                try:
                    data = r.json()
                except ValueError as e:
                    log.debug(
                        'Failed to parse HTTP error response as JSON: %s', e
                    )
                    detail = r.text
                else:
                    try:
                        detail = data['detail']
                    except (KeyError, TypeError):
                        # TypeError: the body was valid JSON but not an
                        # object (e.g. a list or a bare string).
                        detail = ''
                log.error(
                    'Error searching Paperless: HTTP %d %s: %s',
                    r.status_code,
                    r.reason_phrase,
                    detail,
                )
            return docs
        try:
            data = r.json()
        except ValueError as e:
            log.error('Failed to parse HTTP response as JSON: %s', e)
            return docs
        try:
            results = PaperlessNgxSearchResults.parse_obj(data)
        except pydantic.ValidationError as e:
            log.error('Could not parse search response: %s', e)
            return docs
        log.info('Search returned %d documents', results.count)
        if results.next:
            log.warning(
                'Search returned multiple pages of results; '
                'only the results on the first page are used'
            )
        for doc in results.results:
            url = f'{PAPERLESS_URL}/api/documents/{doc.id}/download/'
            r = await self.client.get(url, params={'original': True})
            if r.status_code != 200:
                log.error(
                    'Failed to download document: HTTP %d %s',
                    r.status_code,
                    r.reason_phrase,
                )
                continue
            try:
                size = int(r.headers['Content-Length'])
            except (KeyError, ValueError) as e:
                log.error(
                    'Skipping document ID %d: Cannot determine file size: %s',
                    doc.id,
                    e,
                )
                continue
            if size > MAX_DOCUMENT_SIZE:
                # The format string has three placeholders; the document ID
                # must be passed first or logging fails to format the record.
                log.warning(
                    'Skipping document ID %d: Size (%d bytes) is greater than '
                    'the configured maximum document size (%d bytes)',
                    doc.id,
                    size,
                    MAX_DOCUMENT_SIZE,
                )
                continue
            docs.append((response_filename(r), doc.title, await r.aread()))
        return docs
|
|
|
|
|
|
async def handle_firefly_transaction(xact: FireflyIIITransaction) -> None:
    """Find receipts for each split of a transaction and attach them.

    For every split, Paperless-ngx is searched using the cleaned description,
    the amount and the date; each document found is attached to the split in
    Firefly III.  A split whose amount is not a number is skipped, and a
    failed attachment is logged without stopping the remaining work.
    """
    async with Firefly() as firefly, Paperless() as paperless:
        for split in xact.transactions:
            journal_id = split.transaction_journal_id
            terms = clean_description(split.description)
            try:
                amount = float(split.amount)
            except ValueError as e:
                log.error('Invalid transaction amount: %s', e)
                continue
            receipts = await paperless.find_receipts(
                terms, amount, split.date.date()
            )
            for filename, title, content in receipts:
                try:
                    await firefly.attach_receipt(
                        journal_id, content, filename, title
                    )
                except Exception as e:
                    # Best effort: keep going with the other receipts.
                    log.error(
                        'Failed to attach receipt to transaction ID %d: %s',
                        journal_id,
                        e,
                    )
|
|
|
|
|
|
async def check_host(hostname: str) -> bool:
    """Run ``step ssh check-host <hostname>`` and report whether it succeeded.

    Returns:
        True when the command exits with status 0, False otherwise.
    """
    proc = await asyncio.create_subprocess_exec(
        'step', 'ssh', 'check-host', hostname
    )
    exit_code = await proc.wait()
    return exit_code == 0
|
|
|
|
|
|
def clean_description(text: str) -> str:
    """Reduce a transaction description to a string of search terms.

    The text is lowercased, every character other than ``a``-``z`` and space
    is removed, and the remaining words are de-duplicated with the words in
    ``EXCLUDE_DESCRIPTION_WORDS`` dropped.  Terms keep the order in which
    they first appear, so the result is the same on every run.

    Args:
        text: The raw transaction description.

    Returns:
        The remaining terms joined by single spaces (possibly empty), or
        ``text`` unchanged when nothing at all is left after cleaning.
    """
    cleaned = DESCRIPTION_CLEAN_PATTERN.sub('', text.lower())
    if not cleaned:
        log.warning(
            'Failed to clean transaction description: '
            'text did not match regular expression pattern'
        )
        return text
    # dict.fromkeys() removes duplicates while preserving first-seen order;
    # a plain set would make the term order vary between interpreter runs.
    terms = [
        word
        for word in dict.fromkeys(cleaned.split())
        if word not in EXCLUDE_DESCRIPTION_WORDS
    ]
    return ' '.join(terms)
|
|
|
|
|
|
def response_filename(response: httpx.Response) -> str:
    """Work out a file name for a downloaded response.

    A ``filename*`` parameter of the ``Content-Disposition`` header wins;
    otherwise the ``filename`` parameter is used; otherwise the last segment
    of the request URL path is returned.
    """
    header = response.headers.get('Content-Disposition')
    if header:
        _disposition, params = pyrfc6266.parse(header)
        plain_name = ''
        for param in params:
            if param.name == 'filename*':
                return param.value
            if param.name == 'filename':
                plain_name = param.value
        if plain_name:
            # Unwrap a value that arrived looking like a Python bytes
            # literal, i.e. b'...'.
            looks_like_bytes_repr = (
                plain_name.startswith("b'") and plain_name.endswith("'")
            )
            if looks_like_bytes_repr:
                plain_name = plain_name[2:-1]
            return plain_name
    trimmed_path = response.url.path.rstrip('/')
    return trimmed_path.rpartition('/')[2]
|
|
|
|
|
|
async def sign_key(hostname: str, path: Path) -> tuple[str, str]:
    """Sign an SSH host public key with ``step ssh certificate``.

    The command's combined stdout/stderr is logged line by line on the
    ``step`` child logger.  On success the certificate is read from
    ``<stem>-cert.pub`` next to ``path``.

    Args:
        hostname: Host name the certificate is issued for.
        path: Location of the public key file to sign.

    Returns:
        A ``(certificate file name, certificate text)`` tuple.

    Raises:
        SignError: If the command exits with a non-zero status; the message
            includes the exit code and the captured output.
    """
    cmd = [
        'step', 'ssh', 'certificate', '--sign', '--host', hostname, str(path)
    ]
    p = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    p_log = log.getChild('step')
    assert p.stdout
    buf = bytearray()
    while line := await p.stdout.readline():
        buf += line
        p_log.info(line.rstrip().decode('utf-8', 'replace'))
    rc = await p.wait()
    if rc != 0:
        # Decode leniently, as the logging above does: a strict decode could
        # raise UnicodeDecodeError and hide the real failure from the caller.
        output = buf.decode('utf-8', 'replace')
        raise SignError(
            f'Signing failed: process returned exit code {rc}: {output}'
        )
    cert_path = path.parent / f'{path.stem}-cert.pub'
    log.info(
        'Successfully signed %s for %s as %s',
        path.name,
        hostname,
        cert_path.name,
    )
    # Read in a worker thread so file I/O does not block the event loop.
    cert = await asyncio.to_thread(cert_path.read_text, encoding='utf-8')
    return (cert_path.name, cert)
|
|
|
|
|
|
async def sign_uploaded_key(
    hostname: str, f: fastapi.UploadFile
) -> tuple[str, str]:
    """Write an uploaded public key to a temporary directory and sign it.

    Args:
        hostname: Host name the certificate is issued for.
        f: The uploaded key file.  If its ``Content-Transfer-Encoding``
            header is ``base64`` the content is decoded before signing.

    Returns:
        A ``(certificate file name, certificate text)`` tuple, as returned
        by ``sign_key``.

    Raises:
        SignError: If the upload is larger than ``MAX_KEY_FILE_SIZE``, has
            no usable file name, or signing fails.
    """
    # The reported size may be missing; only compare when it is known.
    if f.size is not None and f.size > MAX_KEY_FILE_SIZE:
        raise SignError(
            f'Refusing to sign key {f.filename}: file too large '
            f'({f.size} bytes, max {MAX_KEY_FILE_SIZE})'
        )
    # The file name is supplied by the client.  Keep only its final path
    # component so it can never point outside the temporary directory.
    filename = Path(f.filename or '').name
    if filename in ('', '.', '..'):
        raise SignError(
            f'Refusing to sign key {f.filename!r}: invalid file name'
        )
    # Read one byte beyond the limit so an oversized upload whose size was
    # not reported is rejected instead of being silently truncated.
    d = await f.read(MAX_KEY_FILE_SIZE + 1)
    if len(d) > MAX_KEY_FILE_SIZE:
        raise SignError(
            f'Refusing to sign key {f.filename}: file too large '
            f'(more than {MAX_KEY_FILE_SIZE} bytes)'
        )
    if f.headers.get('Content-Transfer-Encoding') == 'base64':
        d = base64.b64decode(d)
    with tempfile.TemporaryDirectory() as t:
        path = Path(t) / filename
        # Write in a worker thread so file I/O does not block the event loop.
        await asyncio.to_thread(path.write_bytes, d)
        return await sign_key(hostname, path)
|
|
|
|
|
|
# The ASGI application.  FastAPI's parameter for the application name is
# ``title``; an unknown ``name=`` keyword is silently collected as extra
# data and leaves the title at its default.
app = fastapi.FastAPI(
    title=DIST['Name'],
    version=DIST['Version'],
    docs_url='/api-doc/',
)
|
|
|
|
|
|
@app.on_event('startup')
def on_start() -> None:
    """Send this module's log records, DEBUG and above, to a stream handler."""
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    log.addHandler(handler)
    log.setLevel(logging.DEBUG)
|
|
|
|
|
|
@app.get('/')
def status() -> str:
    """Health check: always answers with the string ``UP``."""
    return 'UP'
|
|
|
|
|
|
@app.post('/hooks/firefly-iii/create')
async def firefly_iii_create(hook: FireflyIIIWebhook) -> None:
    """Webhook endpoint: process the transaction carried by the request."""
    transaction = hook.content
    await handle_firefly_transaction(transaction)
|
|
|
|
|
|
@app.post('/sshkeys/sign', response_model=SSHKeySignResponse)
async def sign_ssh_keys(
    response: fastapi.Response,
    hostname: Annotated[str, fastapi.Form()],
    keys: list[fastapi.UploadFile],
) -> SSHKeySignResponse:
    """Sign every uploaded SSH host key for ``hostname``.

    The host name is validated first.  If the host already has a signed
    certificate the request is refused unless ``ALLOW_RESIGNING_CERTS`` is
    set.  All keys are then signed concurrently; a failure for one key is
    recorded without preventing the others from being signed.

    Returns:
        An ``SSHKeySignResponse``.  When any error was recorded the HTTP
        status is set to 400 and ``success`` is False.
    """
    errors: list[str] = []
    certificates: dict[str, str] = {}
    # ``hostname`` is untrusted form input that ends up on the ``step``
    # command line, so validate it before any subprocess is started.
    if hostname.startswith('-'):
        # A leading dash would be read by the command as an option.
        errors.append(f'Invalid hostname {hostname}')
    elif '.' not in hostname:
        errors.append(
            f'Cannot sign certificate for Single-label hostname {hostname}'
        )
    elif await check_host(hostname):
        msg = f'{hostname} already has a signed certificate'
        if ALLOW_RESIGNING_CERTS:
            log.warning('%s', msg)
        else:
            log.error('%s', msg)
            errors.append(msg)
    if not errors:
        tasks = [sign_uploaded_key(hostname, k) for k in keys]
        for coro in asyncio.as_completed(tasks):
            try:
                name, cert = await coro
            except Exception as e:
                # Report the failure for this key and carry on with the rest.
                log.error('%s', e)
                errors.append(str(e))
            else:
                certificates[name] = cert
    if errors:
        response.status_code = fastapi.status.HTTP_400_BAD_REQUEST
    return SSHKeySignResponse(
        success=not errors,
        errors=errors or None,
        certificates=certificates or None,
    )
|