mirror of
https://github.com/PiBrewing/craftbeerpi4.git
synced 2024-12-23 22:14:56 +01:00
175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
import abc
|
|
import time
|
|
from ipaddress import ip_address
|
|
from ticket_auth import TicketFactory, TicketError
|
|
from .abstract_auth import AbstractAuthentication
|
|
from aiohttp import web
|
|
|
|
|
|
_REISSUE_KEY = 'aiohttp_auth.auth.TktAuthentication.reissue'
|
|
|
|
|
|
class TktAuthentication(AbstractAuthentication):
    """Ticket authentication mechanism based on the ticket_auth library.

    This class is an abstract class that creates a ticket and validates it.
    Storage of the ticket data itself is abstracted to allow different
    implementations to store the cookie differently (encrypted, server side
    etc). Concrete subclasses implement remember_ticket(), forget_ticket()
    and get_ticket().
    """

    def __init__(
            self,
            secret,
            max_age,
            reissue_time=None,
            include_ip=False,
            cookie_name='AUTH_TKT'):
        """Initializes the ticket authentication mechanism.

        Args:
            secret: Byte sequence used to initialize the ticket factory.
            max_age: Integer representing the number of seconds to allow the
                ticket to remain valid for after being issued.
            reissue_time: Integer representing the number of seconds before
                a valid login will cause a ticket to be reissued. If this
                value is 0, a new ticket will be reissued on every request
                which requires authentication. If this value is None, no
                tickets will be reissued, and the max_age will always expire
                the ticket.
            include_ip: If true, requires the clients ip details when
                calculating the ticket hash
            cookie_name: Name to use to reference the ticket details.
        """
        self._ticket = TicketFactory(secret)
        self._max_age = max_age

        # The threshold is stored inverted, as "seconds remaining before
        # expiry". get() reissues once now >= valid_until - _reissue_time.
        # Since valid_until is issue time + max_age, that means once
        # reissue_time seconds have passed since issue. A reissue_time that
        # is not smaller than max_age could never fire before the ticket
        # expires, so reissuing is disabled in that case.
        if (self._max_age is not None and
                reissue_time is not None and
                reissue_time < self._max_age):
            self._reissue_time = max_age - reissue_time
        else:
            self._reissue_time = None

        self._include_ip = include_ip
        self._cookie_name = cookie_name

    @property
    def cookie_name(self):
        """Returns the name of the cookie stored in the session"""
        return self._cookie_name

    async def remember(self, request, user_id):
        """Called to store the userid for a request.

        This function creates a ticket from the request and user_id, and calls
        the abstract function remember_ticket() to store the ticket.

        Args:
            request: aiohttp Request object.
            user_id: String representing the user_id to remember
        """
        ticket = self._new_ticket(request, user_id)
        await self.remember_ticket(request, ticket)

    async def forget(self, request):
        """Called to forget the userid for a request

        This function calls the forget_ticket() function to forget the ticket
        associated with this request.

        Args:
            request: aiohttp Request object
        """
        await self.forget_ticket(request)

    async def get(self, request):
        """Gets the user_id for the request.

        Gets the ticket for the request using the get_ticket() function, and
        authenticates the ticket. If the ticket is valid and old enough (see
        reissue_time), a replacement ticket is stashed on the request for
        process_response() to persist.

        Args:
            request: aiohttp Request object.

        Returns:
            The userid for the request, or None if the ticket is not
            authenticated.
        """
        ticket = await self.get_ticket(request)
        if ticket is None:
            return None

        now = time.time()
        try:
            # The returned fields expose (at least) user_id and valid_until,
            # both used below. TicketError means the ticket did not validate.
            fields = self._ticket.validate(ticket, self._get_ip(request), now)
        except TicketError:
            return None

        # Check if we need to reissue a ticket
        if (self._reissue_time is not None and
                now >= (fields.valid_until - self._reissue_time)):
            # Only stash the new ticket here. It is persisted later by
            # process_response(), and only for successful responses.
            request[_REISSUE_KEY] = self._new_ticket(request, fields.user_id)

        return fields.user_id

    async def process_response(self, request, response):
        """If a reissue was requested, only reissue if the response was a
        valid 2xx response.

        Args:
            request: aiohttp Request object.
            response: The response being returned for the request.
        """
        if _REISSUE_KEY not in request:
            return

        # Skip non web.Response objects first, so no attribute is read from
        # an unexpected response type. A response that is already prepared
        # has sent its headers, so (presumably cookie-based) storage could
        # no longer take effect. `prepared` replaces the deprecated
        # `started` alias, which newer aiohttp releases do not provide.
        if (not isinstance(response, web.Response) or
                response.prepared or
                not 200 <= response.status <= 299):
            return

        await self.remember_ticket(request, request[_REISSUE_KEY])

    @abc.abstractmethod
    async def remember_ticket(self, request, ticket):
        """Abstract function called to store the ticket data for a request.

        Args:
            request: aiohttp Request object.
            ticket: String like object representing the ticket to be stored.
        """
        pass

    @abc.abstractmethod
    async def forget_ticket(self, request):
        """Abstract function called to forget the ticket data for a request.

        Args:
            request: aiohttp Request object.
        """
        pass

    @abc.abstractmethod
    async def get_ticket(self, request):
        """Abstract function called to return the ticket for a request.

        Args:
            request: aiohttp Request object.

        Returns:
            A ticket (string like) object, or None if no ticket is available
            for the passed request.
        """
        pass

    def _get_ip(self, request):
        """Returns the client ip_address when include_ip is set, else None.

        None is also returned when the peer address cannot be determined
        (no transport, or an empty peername).
        """
        if not self._include_ip:
            return None

        # request.transport may be None once the connection has gone away.
        # Treat that like a missing peername instead of crashing.
        transport = request.transport
        if transport is None:
            return None

        peername = transport.get_extra_info('peername')
        if not peername:
            return None

        return ip_address(peername[0])

    def _new_ticket(self, request, user_id):
        """Creates a ticket for user_id valid for max_age seconds from now."""
        ip = self._get_ip(request)
        # NOTE(review): __init__ tolerates max_age=None, but this addition
        # would raise TypeError in that case. Confirm None is never passed.
        valid_until = int(time.time()) + self._max_age
        return self._ticket.new(user_id, valid_until=valid_until, client_ip=ip)
|