Files

100 lines
3.4 KiB
Python

from collections import defaultdict
from datetime import date, datetime
from typing import Iterable
import requests
from regiojet_search.cache import cache
from regiojet_search.db import fetch_results, save_results
from regiojet_search.models import City, Country, Fare, Result
from regiojet_search.settings import LOCATIONS_URL, SEARCH_URL
def fetch_cities() -> list[City]:
    """Return all cities listed by the locations endpoint, preferring the cache.

    The cache entry ``"cities"`` is consulted first. On a miss the list is
    downloaded from ``LOCATIONS_URL``, flattened across countries into ``City``
    objects (each carrying its ``Country``), stored in the cache and returned.

    Returns:
        Every city from every country in the response, in response order.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If the endpoint stops responding.
    """
    cache_key = "cities"
    cached_cities = cache.get(cache_key)
    if cached_cities is not None:
        # presumably cache.get returns mappings of City fields — confirm against the cache wrapper
        return [City(**cached_city) for cached_city in cached_cities]

    # requests waits indefinitely unless a timeout is passed explicitly.
    response = requests.get(LOCATIONS_URL, timeout=30)
    # Stop on an error status instead of parsing (and then caching) an error payload.
    response.raise_for_status()

    cities: list[City] = []
    for raw_country in response.json():
        country = Country(code=raw_country["code"], name=raw_country["country"])
        cities.extend(
            City(id=raw_city["id"], name=raw_city["name"], country=country)
            for raw_city in raw_country["cities"]
        )

    cache.set(cache_key, cities)
    return cities
def search(from_city: City, to_city: City, departure_date: date, currency: str) -> list[Result]:
    """Find connections between two cities for one departure date.

    Sources are tried in order: the cache, then the database, then the remote
    search endpoint. Database hits are written back to the cache; endpoint
    responses are written to both the database and the cache.

    Args:
        from_city: Origin city; its ``id`` is sent to the endpoint and its
            ``name`` is used for the slug and for the parsed results.
        to_city: Destination city, used the same way as ``from_city``.
        departure_date: Requested day of departure.
        currency: Currency code; sent upper-cased in the ``X-Currency`` header.

    Returns:
        Cached or stored results as-is, or, when fetched from the endpoint,
        only those results departing on ``departure_date``.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If the endpoint stops responding.
    """
    slug = Result.slugify(from_city.name, to_city.name, departure_date, currency)

    # NOTE(review): this lookup only hits if persist_results stores entries under
    # the same "journey:"-prefixed key — verify the two stay in sync.
    cache_key = f"journey:{slug}"
    cached_results = cache.get(cache_key)
    if cached_results is not None:
        return [Result(**cached_result) for cached_result in cached_results]

    # An empty/falsy database answer falls through to the remote search.
    if results_from_db := fetch_results(slug):
        # Already in the database, so only the cache needs refreshing.
        persist_results(results_from_db, save_to_database=False)
        return results_from_db

    response = requests.get(
        SEARCH_URL,
        headers={"X-Currency": currency.upper()},
        params={
            "tariffs": "REGULAR",
            "fromLocationType": "CITY",
            "toLocationType": "CITY",
            "fromLocationId": from_city.id,
            "toLocationId": to_city.id,
            "departureDate": departure_date.isoformat(),
        },
        # requests waits indefinitely unless a timeout is passed explicitly.
        timeout=30,
    )
    # Stop on an error status instead of indexing into an error payload.
    response.raise_for_status()

    results = list(parse_results(response.json()["routes"], from_city, to_city, currency=currency))
    # Everything returned by the endpoint is persisted; the caller only gets
    # the routes that depart on the requested date.
    persist_results(results)
    return discard_different_departure_dates(results, departure_date)
def parse_results(raw_results: list[dict], from_city: City, to_city: City, currency: str) -> Iterable[Result]:
    """Lazily turn raw route dictionaries into ``Result`` objects.

    Args:
        raw_results: Route dictionaries as found under ``"routes"`` in the
            search response.
        from_city: City whose ``name`` becomes each result's ``origin``.
        to_city: City whose ``name`` becomes each result's ``destination``.
        currency: Currency attached to every parsed fare.

    Yields:
        One ``Result`` per route, in input order.
    """
    for route in raw_results:
        departs_at = datetime.fromisoformat(route["departureTime"])
        arrives_at = datetime.fromisoformat(route["arrivalTime"])
        origin_name = from_city.name
        destination_name = to_city.name
        price = Fare(amount=route["priceFrom"], currency=currency)
        # Only the first listed vehicle type is kept, lower-cased.
        vehicle_type = route["vehicleTypes"][0].lower()

        yield Result(
            departure=departs_at,
            arrival=arrives_at,
            origin=origin_name,
            destination=destination_name,
            fare=price,
            type=vehicle_type,
            source_id=route["departureStationId"],
            destination_id=route["arrivalStationId"],
            free_seats=route["freeSeatsCount"],
            carrier=route["vehicleStandardKey"],
        )
def discard_different_departure_dates(results: list[Result], departure: date) -> list[Result]:
    """Return the results whose departure falls on the ``departure`` day.

    The time of day is ignored: each result's ``departure`` is reduced to its
    calendar date before comparing. Input order is preserved and the input
    list is not modified.
    """
    same_day: list[Result] = []
    for candidate in results:
        if candidate.departure.date() == departure:
            same_day.append(candidate)
    return same_day
def persist_results(results: list[Result], save_to_redis: bool = True, save_to_database: bool = True) -> None:
    """Store results grouped by their slug in the database and/or the cache.

    Results are bucketed by ``result.slug`` so that each bucket is written as
    one unit: one ``save_results`` call and one cache entry per slug.

    Args:
        results: Results to store; they may belong to several different slugs.
        save_to_redis: When true, write each bucket to the cache.
        save_to_database: When true, pass each bucket to ``save_results``.
    """
    buckets: dict[str, list[Result]] = defaultdict(list)
    for result in results:
        buckets[result.slug].append(result)

    for slug, subset_of_results in buckets.items():
        if save_to_database:
            save_results(subset_of_results)
        if save_to_redis:
            # search() looks journeys up under "journey:<slug>"; writing under the
            # bare slug would leave entries that lookup can never find.
            # NOTE(review): assumes result.slug equals Result.slugify(...) as used by search().
            cache.set(f"journey:{slug}", subset_of_results)