"""Minimal asyncio HTTP/1.1 server with decorator-based routing and CORS support."""

import asyncio
import functools
import json
import re
import traceback
import urllib.parse
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any, Awaitable, Callable, Optional

from .responses import HTTPResponse, JSONResponse, Response

# Matches a route placeholder such as ``<user_id>`` and captures its name.
PR = re.compile(r"\<([a-zA-Z_][a-zA-Z0-9_]*)\>")

# Characters a ``<name>`` placeholder may match; notably excludes "/" and "?".
_PARAM_RE = r"[a-zA-Z0-9\-._~:%&=]+"

# A route handler: an async callable invoked as ``handler(request=..., **path_params)``.
Handler = Callable[..., Awaitable[Response]]


class CORS:
    """Cross-Origin Resource Sharing settings for an :class:`App`."""

    Origins: set[str]  # origins that receive Access-Control-Allow-Origin
    Methods: set[str]  # methods that may be advertised in Access-Control-Allow-Methods
    Disabled: bool  # when True, the App never emits CORS headers

    def __init__(self):
        self.Disabled = False
        # Must be a real set: ``{}`` would be an empty dict and break ``.add()``.
        self.Origins: set[str] = set()
        self.Methods: set[str] = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"}


class Headers:
    """Case-insensitive mapping of HTTP header names to values (last value wins)."""

    def __init__(self):
        self._d: dict[str, str] = {}

    def get(self, key: str, default: Optional[Any] = None) -> str | Any:
        """Return the value of header *key* (case-insensitive), or *default*."""
        return self._d.get(key.lower(), default)

    def set(self, key: str, value: str) -> None:
        """Set header *key* (case-insensitive) to *value*, replacing any old value."""
        self._d[key.lower()] = value

    def __str__(self):
        return str(self._d)

    def __repr__(self):
        return f"Headers({self._d!r})"

    def __contains__(self, key):
        # Normalise the same way as get/set so membership is case-insensitive too.
        return isinstance(key, str) and key.lower() in self._d


class Request:
    """A parsed incoming HTTP request."""

    def __init__(self, method: str, path: str, headers: Headers, body: bytes):
        self.method = method
        self.path = path  # full request target, including any query string
        self.headers = headers
        self.body = body

    def __str__(self):
        return str(
            {
                "method": self.method,
                "path": self.path,
                "headers": self.headers,
                "body": self.body,
            }
        )

    def json(self) -> Any:
        """Decode the body as JSON.

        Returns None when the body is not valid JSON or not validly encoded text.
        """
        try:
            return json.loads(self.body)
        except (JSONDecodeError, UnicodeDecodeError):
            return None


async def _default_404_route(request: Request) -> Response:
    """Default handler used when no route pattern matches the path."""
    return Response(404, ["Content-Type: text/html"], b"404 Not Found contact admin")


async def _default_405_route(request: Request) -> Response:
    """Default handler used when the path matches but the method is not registered."""
    return Response(405, ["Content-Type: text/html"], b"405 Method Not Allowed")


class App:
    """Asynchronous HTTP application: a route table plus a per-connection handler."""

    # HTTP methods that may be registered through _serve.
    _METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"})

    def __init__(self):
        # Compiled path pattern -> {HTTP method -> handler}.
        self.routes: dict[re.Pattern[str], dict[str, Handler]] = {}
        # HTTP status code -> handler used when routing fails.
        self.error_routes: dict[int, Handler] = {
            404: _default_404_route,
            405: _default_405_route,
        }
        self.CORS = CORS()

    def _pattern_to_regex(self, temp) -> re.Pattern[str]:
        """Compile a route template such as ``/users/<id>`` into a regex.

        Each ``<name>`` placeholder becomes a named group; all other text in the
        template is matched literally (it is ``re.escape``-d).
        """
        pieces: list[str] = []
        last = 0
        for m in PR.finditer(temp):
            pieces.append(re.escape(temp[last : m.start()]))
            pieces.append(f"(?P<{m[1]}>{_PARAM_RE})")
            last = m.end()
        pieces.append(re.escape(temp[last:]))
        return re.compile("".join(pieces))

    def _serve(self, path: str, method: str, func: Handler):
        """Register *func* for *method* requests whose path matches template *path*.

        Raises:
            RuntimeError: if *method* is not a supported HTTP method.
            RuntimeWarning: if a handler for this path and method already exists.
        """
        if method not in self._METHODS:
            raise RuntimeError(f'Invalid method "{method}".')
        handlers = self.routes.setdefault(self._pattern_to_regex(path), {})
        if method in handlers:
            raise RuntimeWarning(f'Path "{path}" already exists.')
        handlers[method] = func

    def _route_decorator(self, path: str, method: str):
        """Build a decorator that registers the decorated handler for *method*."""

        def decorator(func: Handler):
            self._serve(path, method, func)
            return func

        return decorator

    def GET(self, path: str):
        """Decorator to register a GET HTTP route.

        A HEAD route is registered for the same path as well: it runs the same
        handler and sets ``no_header = True`` on the returned response (presumably
        telling ``Response.render`` to omit the body — confirm in ``.responses``).
        """

        def decorator(func: Handler):
            self._serve(path, "GET", func)

            @functools.wraps(func)
            async def head_wrapper(*args, **kwargs):
                res: Response = await func(*args, **kwargs)
                res.no_header = True
                return res

            self._serve(path, "HEAD", head_wrapper)
            return func

        return decorator

    def POST(self, path: str):
        """Decorator to register a POST HTTP route."""
        return self._route_decorator(path, "POST")

    def PUT(self, path: str):
        """Decorator to register a PUT HTTP route."""
        return self._route_decorator(path, "PUT")

    def DELETE(self, path: str):
        """Decorator to register a DELETE HTTP route."""
        return self._route_decorator(path, "DELETE")

    def PATCH(self, path: str):
        """Decorator to register a PATCH HTTP route."""
        return self._route_decorator(path, "PATCH")

    def HEAD(self, path: str):
        """Decorator to register a HEAD HTTP route."""
        return self._route_decorator(path, "HEAD")

    def error(self, code):
        """Decorator to register the handler used for HTTP status *code* (e.g. 404)."""

        def decorator(func):
            self.error_routes[code] = func
            return func

        return decorator

    def resolve(self, path, method) -> tuple[Callable[..., Awaitable[Response]], dict]:
        """Find the handler for *path*/*method* and its URL-decoded path parameters.

        The first registered pattern that fully matches *path* decides the result:
        its handler for *method*, or the 405 error route if it has none. With no
        matching pattern the 404 error route is returned.
        """
        for pattern, route in self.routes.items():
            if m := pattern.fullmatch(path):
                if method not in route:
                    return self.error_routes[405], {}
                return route[method], {
                    k: urllib.parse.unquote(v) for k, v in m.groupdict().items()
                }
        return self.error_routes[404], {}

    def methods(self, path) -> set[str]:
        """Return the methods registered for the first pattern matching *path*."""
        for pattern, route in self.routes.items():
            if pattern.fullmatch(path):
                return set(route.keys())
        return set()

    def _allowed_origin(self, headers: Headers) -> Optional[str]:
        """Return the request's Origin if CORS is enabled and it is whitelisted."""
        if self.CORS.Disabled:
            return None
        origin = headers.get("origin")
        if origin is not None and origin in self.CORS.Origins:
            return origin
        return None

    async def _read_request(self, reader: asyncio.StreamReader) -> Optional[Request]:
        """Read the request line, headers and body.

        Returns None if the peer closed early or sent a request this server does
        not handle (malformed request line, non-HTTP/1.1, bad Content-Length,
        truncated body).
        """
        request_line = await reader.readline()
        if not request_line:
            return None
        parts = request_line.decode(encoding="utf-8").strip().split()
        if len(parts) < 3:
            return None
        method, path, protocol = parts[0], parts[1], parts[2]
        if protocol != "HTTP/1.1":
            # Explicit check rather than `assert`, which is stripped under -O.
            return None

        headers = Headers()
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):  # blank line or EOF ends the headers
                break
            text = line.decode("utf-8").strip()
            if ":" in text:
                key, value = text.split(":", 1)
                headers.set(key.strip(), value.strip())

        try:
            content_length = int(headers.get("content-length", 0))
        except ValueError:
            return None
        if content_length < 0:
            return None
        body = b""
        if content_length:
            # read(n) may return fewer than n bytes; readexactly waits for all of them.
            try:
                body = await reader.readexactly(content_length)
            except asyncio.IncompleteReadError:
                return None
        return Request(method=method, path=path, headers=headers, body=body)

    def _preflight_response(self, request: Request) -> Response:
        """Answer a CORS preflight (OPTIONS) request."""
        # FIXME Should handle responding with available methods as well
        origin = self._allowed_origin(request.headers)
        if origin is None:
            return Response(403, ["Vary: Origin"], b"")
        head = [
            "Content-Type: text/plain",
            "Content-Length: 0",
            f"Access-Control-Allow-Origin: {origin}",
            f"Access-Control-Allow-Methods: {','.join(sorted(self.CORS.Methods))}",
            "Access-Control-Allow-Headers: Content-Type,Authorization",
            "Vary: Origin",
        ]
        return Response(200, head, b"")

    async def _dispatch(self, request: Request) -> "Response | bytes":
        """Route *request* to its handler and decorate the response with CORS headers."""
        # Route on the path only; the query string stays available via request.path.
        route_path = request.path.partition("?")[0]
        route, kwargs = self.resolve(route_path, request.method)
        response = await route(request=request, **kwargs)
        if isinstance(response, bytes):
            # Pre-rendered raw HTTP response (old-style error handler): send as-is.
            return response
        response.headers.append("Vary: Origin")
        origin = self._allowed_origin(request.headers)
        if origin is not None:
            allowed = self.CORS.Methods & self.methods(route_path)
            response.headers.append(f"Access-Control-Allow-Origin: {origin}")
            response.headers.append(
                f"Access-Control-Allow-Methods: {','.join(sorted(allowed))}"
            )
            # No trailing CRLF here: it would end the header section early.
            response.headers.append(
                "Access-Control-Allow-Headers: Content-Type,Authorization"
            )
        return response

    async def handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """Handle one incoming HTTP/1.1 connection: read one request, reply, close.

        Handler exceptions are logged. When nothing has been written yet, a
        best-effort 500 response is also sent.
        """
        responded = False
        try:
            request = await self._read_request(reader)
            if request is None:
                return
            if request.method == "OPTIONS":
                response = self._preflight_response(request)
            else:
                response = await self._dispatch(request)
            payload = response if isinstance(response, bytes) else response.render(self)
            responded = True
            writer.write(payload)
            await writer.drain()
        except Exception as e:
            print(f"Internal Server Error: {e}")
            traceback.print_exc()
            if not responded:
                try:
                    writer.write(
                        Response(
                            500, ["Content-Type: text/plain"], b"500 Internal Server Error"
                        ).render(self)
                    )
                    await writer.drain()
                except Exception:
                    pass  # best effort: the connection may already be broken
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass  # peer reset; nothing left to clean up

    async def run(self, host="127.0.0.1", port=8000):
        """Start the async server and serve forever."""
        server = await asyncio.start_server(self.handle_client, host, port)
        print(f"Serving on http://{host}:{port}")
        async with server:
            await server.serve_forever()


# Matches a ``{{ name }}`` template placeholder and captures ``name``.
_value_pattern = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z_0-9]*)\s*\}\}")


def render(file: str | Path, variables: Optional[dict[str, Any]] = None) -> Response:
    """Render an HTML template file, substituting ``{{ name }}`` placeholders.

    Placeholders whose name is not in *variables* are left untouched. Values are
    inserted literally via ``str()`` and are NOT HTML-escaped, so escape
    untrusted data before passing it in.
    """
    # TODO Move to another module
    if variables is None:
        variables = {}
    content: str = Path(file).read_text(encoding="utf-8")

    def substitute(m: re.Match[str]) -> str:
        name = m[1]
        return str(variables[name]) if name in variables else m[0]

    # One pass with a function replacement: values are not parsed as re templates
    # and inserted text is not re-expanded.
    content = _value_pattern.sub(substitute, content)
    return HTTPResponse(content, content_type="text/html; charset=utf-8")


def redirect(location: str) -> Response:
    """Return a 307 Temporary Redirect response pointing at *location*.

    Raises:
        ValueError: if *location* contains CR or LF (header injection).
    """
    # TODO Move to another module
    if "\r" in location or "\n" in location:
        raise ValueError("redirect location must not contain CR or LF characters")
    return Response(307, [f"Location: {location}"], b"")


def JSONAPI(func):
    """Decorator that turns an async route's return value into a JSONResponse.

    The route may return a ``dict`` (sent with the JSONResponse default status)
    or an ``(int status, dict payload)`` tuple.

    Raises:
        RuntimeError: if the route returns anything else.
    """
    # TODO Move to another module

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        if isinstance(result, dict):
            return JSONResponse(result)
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[0], int)
            and isinstance(result[1], dict)
        ):
            status, payload = result
            return JSONResponse(payload, status)
        raise RuntimeError("Return value of JSONAPI route is not a dictionary")

    return wrapper